// blob: 0bb6b00ff4461d6d958fbfe091f7ed9f6c4fb905 [file] [log] [blame]
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-outside-compilation | FILECHECK_OPTS="" FileCheck %s
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// Tests that TPU cluster with no outside compilation does not generate parallel_execute.
// No op in the cluster below carries an `_xla_outside_compilation` attribute.
// CHECK-LABEL: func @no_outside_compilation
func.func @no_outside_compilation() -> tensor<2xi32> {
  // CHECK-NOT: "tf_device.parallel_execute"
  %0 = "tf_device.cluster"() ({
    %1 = "tf.A"() : () -> tensor<2xi32>
    %2 = "tf.B"(%1) : (tensor<2xi32>) -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}
// Tests that an `_xla_outside_compilation` attribute placed on an op that
// lives outside of the device cluster ("tf.B" below) is not present in the
// output.
// CHECK-LABEL: func @attribute_outside_of_cluster
func.func @attribute_outside_of_cluster() -> tensor<2xi32> {
  // CHECK-NOT: _xla_outside_compilation
  %0 = "tf_device.cluster"() ({
    %1 = "tf.A"() : () -> tensor<2xi32>
    tf_device.return %1 : tensor<2xi32>
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
  %2 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with no input or output dependencies.
// "tf.B" is the only outside compiled op; it is expected inside a
// tf_device.launch on the CPU device, with "tf.A" remaining in the cluster.
// CHECK-LABEL: func @nodep_single_outside_compilation
func.func @nodep_single_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK-NEXT: tf_device.return
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
  "tf_device.cluster"() ({
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    // Generic-form regions need an explicit terminator.
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  func.return
}
// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependencies.
// "tf.B", "tf.C" and "tf.D" share the "cluster1" attribute and are expected
// back to back in one tf_device.launch; "tf.A" and "tf.E" stay in the cluster.
// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
func.func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NEXT: "tf.C"
  // CHECK-NEXT: "tf.D"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK-NEXT: "tf.E"
  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
  "tf_device.cluster"() ({
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.E"() : () -> ()
    // Generic-form regions need an explicit terminator.
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  func.return
}
// Tests extraction of multiple outside compiled clusters with no input or output dependencies.
// "tf.B" ("cluster1") and "tf.D" ("cluster2") are expected in a single
// tf_device.launch, with no second launch before the device cluster.
// CHECK-LABEL: func @nodep_multiple_outside_compilation
func.func @nodep_multiple_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK: "tf_device.launch"
  // CHECK: "tf.B"
  // CHECK-NEXT: "tf.D"
  // The negative directive below previously lacked its trailing colon and was
  // therefore silently ignored by FileCheck.
  // CHECK-NOT: "tf_device.launch"
  // CHECK: "tf_device.cluster"
  "tf_device.cluster"() ({
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
    "tf.E"() : () -> ()
    // Generic-form regions need an explicit terminator.
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  func.return
}
// Tests extraction of a single outside compiled cluster with single TPU cluster return.
// The cluster result is expected to be forwarded through parallel_execute and
// then through the enclosing replicate.
// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
func.func @single_tpu_return_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: "tf.B"
  // CHECK-NEXT: tf_device.return
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      "tf.A"() : () -> ()
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %3 = "tf.C"() : () -> tensor<2xi32>
      tf_device.return %3 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.
// The cluster yields two values (tensor<3xf32>, tensor<2xi32>), so both the
// cluster and the parallel_execute are expected to have two results.
// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
func.func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<3xf32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2, %3 = "tf_device.cluster"() ({
      %4 = "tf.A"() : () -> tensor<3xf32>
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %5 = "tf.C"() : () -> tensor<2xi32>
      tf_device.return %4, %5 : tensor<3xf32>, tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<3xf32>, tensor<2xi32>)
    tf_device.return %2, %3 : tensor<3xf32>, tensor<2xi32>
  }
  func.return %1 : tensor<3xf32>
}
// Tests extraction of a single outside compiled cluster with single device->host input.
// The outside compiled "tf.B" consumes the result of "tf.A" from the cluster,
// so a host-side receive and a device-side host-compute op with the matching
// "host_compute_channel_0_args" key are expected.
// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func.func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:[a-z_0-9]+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_args"
  // CHECK: "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: host_mlir_module = ""
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<2xi32>
      tf_device.return %4 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests value is added as operand to XlaHostCompute op only if defining op is
// in TPU cluster.
// Here "tf.A" is defined in the replicate body but outside of the cluster, so
// the extracted "tf.B" is expected to use its result directly.
// CHECK-LABEL: func @single_outside_compiled_input_from_outside_device_cluster
func.func @single_outside_compiled_input_from_outside_device_cluster(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK-NEXT: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: "tf.B"(%[[A_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.C"()
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %3 = "tf.A"() : () -> (tensor<2xi32>)
    %2 = "tf_device.cluster"() ({
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<2xi32>
      tf_device.return %4 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with single host->device output.
// The cluster is not wrapped in tf_device.replicate; the expectations below
// require no device ordinal placeholder and a send op with device_ordinal = 0.
// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation_not_replicated
func.func @single_outside_compiled_output_single_outside_compilation_not_replicated(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-NOT: "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: device_ordinal = 0
  // CHECK-SAME: key = "host_compute_channel_0_retvals"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %0 = "tf_device.cluster"() ({
    %1 = "tf.A"() : () -> (tensor<2xi32>)
    %2 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<2xi32>)
    %3 = "tf.C"(%2) : (tensor<2xi32>) -> tensor<2xi32>
    tf_device.return %3 : tensor<2xi32>
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with single
// host->device output when the cluster is nested in tf_device.replicate.
// The expectations below require a device ordinal placeholder and the V2 send
// op that takes it as an operand.
// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation_replicated
func.func @single_outside_compiled_output_single_outside_compilation_replicated(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
  // CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_retvals"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<2xi32>)
      %5 = "tf.C"(%4) : (tensor<2xi32>) -> tensor<2xi32>
      tf_device.return %5 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster host output returned by TPU cluster.
// The result of the outside compiled "tf.B" is returned by the cluster itself,
// so the cluster is expected to return the host-compute op result instead.
// CHECK-LABEL: func @return_host_output_outside_compilation
func.func @return_host_output_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK: _xla_has_host_transfer = true
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK: _xla_has_host_transfer = true
  // CHECK-SAME: key = "host_compute_channel_0_retvals"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK: tf_device.return %[[HOST_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> (tensor<2xi32>)
      %5 = "tf.C"(%3) : (tensor<2xi32>) -> (tensor<2xi32>)
      tf_device.return %4 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with single input/output.
// The outside compiled "tf.B" consumes "tf.A" from the cluster and feeds
// "tf.C" in the cluster, so both a receive and a send are expected on the host.
// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation
func.func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHostV2"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_retvals"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> (tensor<2xi32>)
      %5 = "tf.C"(%4) : (tensor<2xi32>) -> tensor<2xi32>
      tf_device.return %5 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests host to device communication is added only if value is used for ops
// that are not outside compiled.
// The result of the outside compiled "tf.B" is only used by "tf.D", which sits
// inside an outside compiled tf.IfRegion, so "tf.D" is expected to consume the
// "tf.B" result directly on the host.
// CHECK-LABEL: func @single_outside_compiled_output_used_for_another_host_op
func.func @single_outside_compiled_output_used_for_another_host_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
  // CHECK-NEXT: "tf.IfRegion"(%[[A_OUTPUT]])
  // CHECK-NEXT: "tf.D"(%[[B_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.C"()
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %3 = "tf.A"() : () -> (tensor<i1>)
    %2 = "tf_device.cluster"() ({
      %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<2xi32>)
      "tf.IfRegion"(%3) ({
        "tf.D"(%4) : (tensor<2xi32>) -> ()
        "tf.Yield"() : () -> ()
      }, {
        "tf.Yield"() : () -> ()
      }) { _xla_outside_compilation = "cluster1", is_stateless = false} : (tensor<i1>) -> ()
      %5 = "tf.C"() : () -> (tensor<2xi32>)
      tf_device.return %5 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with multiple input/output.
// The outside compiled "tf.C" takes two cluster values and produces two values
// used in the cluster, so two-element receive/send ops are expected.
// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
func.func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // The host-side result belongs to "tf.C"; it gets its own capture name so it
  // is not confused with the "tf.B" capture bound further down.
  // CHECK: %[[C_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
  // CHECK: "tf._XlaSendFromHostV2"(%[[C_OUTPUT]]#0, %[[C_OUTPUT]]#1, %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_retvals"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
  // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"() : () -> (tensor<2xi32>)
      %5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
      %7 = "tf.D"(%5) : (tensor<2xi32>) -> tensor<2xi32>
      %8 = "tf.E"(%6) : (tensor<2xi32>) -> tensor<2xi32>
      tf_device.return %8 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of multiple outside compiled clusters with input/output.
// "tf.B" ("cluster1") and "tf.D" ("cluster2") are chained through the
// in-cluster "tf.C", so two host-compute ops on channels 0 and 1 are expected.
// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
func.func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
  // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
  // CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_1_retvals"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
  // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_1_retvals"
  // CHECK: "tf.E"(%[[HOST_OUTPUT2]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> (tensor<2xi32>)
      %5 = "tf.C"(%4) : (tensor<2xi32>) -> (tensor<2xi32>)
      %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<2xi32>) -> (tensor<2xi32>)
      %7 = "tf.E"(%6) : (tensor<2xi32>) -> tensor<2xi32>
      tf_device.return %7 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with arg input and single device->host input.
// The outside compiled "tf.B" takes the function argument %arg0 and one
// cluster value; only the cluster value is expected to go through a receive.
// CHECK-LABEL: func @mixed_input_single_outside_compilation
func.func @mixed_input_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_args"
  // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>, tensor<2xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<2xi32>
      tf_device.return %4 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of multiple outside compiled clusters with single device->host input.
// "tf.B" ("cluster1") and "tf.D" ("cluster2") each consume one cluster value,
// so one receive/host-compute pair per channel (0 and 1) is expected.
// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
func.func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_args"
  // CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
  // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_1_args"
  // CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
  // CHECK: "tf._XlaHostComputeMlir"(%[[C_OUTPUT]])
  // CHECK-SAME: send_key = "host_compute_channel_1_args"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<2xi32>
      "tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<2xi32>) -> ()
      tf_device.return %4 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster with multiple device->host inputs.
// "tf.C" and "tf.D" share the "cluster1" attribute; "tf.D" reuses the value
// already received for "tf.C" and needs one additional receive.
// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
func.func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_args"
  // CHECK: "tf.C"(%[[RECV_OUTPUT_1]])
  // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_1_args"
  // CHECK: "tf.D"(%[[RECV_OUTPUT_2]], %[[RECV_OUTPUT_1]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]])
  // CHECK-SAME: send_key = "host_compute_channel_1_args"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"() : () -> (tensor<2xi32>)
      "tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> ()
      "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>, tensor<2xi32>) -> ()
      %5 = "tf.E"() : () -> tensor<2xi32>
      tf_device.return %5 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests only directly used results of tpu cluster are remapped with
// parallel_execute.
// The cluster has two results but the replicate body returns only the second
// one (%2#1); the same index is expected on the parallel_execute result.
// CHECK-LABEL: func @remapped_results
func.func @remapped_results(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<2xi32>
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2:2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> (tensor<2xi32>)
      %5:2 = "tf.C"(%4) : (tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
      tf_device.return %5#0, %5#1 : tensor<2xi32>, tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<2xi32>, tensor<2xi32>)
    tf_device.return %2#1 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster inside a tf.IfRegion op.
// The outside compiled "tf.D" sits in the then-branch of an IfRegion that is
// itself not outside compiled, so the predicate is communicated to the host on
// "if_predicate_channel_1" and the operands of "tf.D" on channel 0.
// NOTE(review): the predicate receive and host-side IfRegion expectations were
// missing (the "if_predicate_channel_1" key was being matched against a
// placeholder op line) and have been reconstructed; confirm against pass output.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_if
func.func @outside_compiled_ops_inside_tf_if(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "if_predicate_channel_1"
  // CHECK-NEXT: "tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
  // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_args"
  // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
  // CHECK-NOT: "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK: "tf.Yield"() : () -> ()
  // CHECK: _else_func_name = "test_else_name"
  // The same-line directive below previously lacked its colon and was ignored.
  // CHECK-SAME: _then_func_name = "test_then_name"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
  // CHECK: "tf._XlaHostComputeMlir"
  // CHECK-SAME: key = "if_predicate_channel_1"
  // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
  // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  // CHECK-NEXT: "tf.Yield"() : () -> ()
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"() : () -> (tensor<2xi32>)
      %6 = "tf.G"() : () -> (tensor<i1>)
      "tf.IfRegion"(%6) ({
        "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>, tensor<2xi32>) -> ()
        "tf.Yield"() : () -> ()
      }, {
        "tf.Yield"() : () -> ()
      }) { is_stateless = false, _then_func_name = "test_then_name", _else_func_name = "test_else_name"} : (tensor<i1>) -> ()
      %5 = "tf.E"() : () -> tensor<2xi32>
      tf_device.return %5 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of a single outside compiled cluster inside a tf.IfRegion op with dynamic shape.
// Same structure as the static-shape IfRegion test, except that "tf.B"
// produces a tensor<?xi32> that is consumed by the outside compiled "tf.D".
// NOTE(review): the predicate receive and host-side IfRegion expectations were
// missing (the "if_predicate_channel_1" key was being matched against a
// placeholder op line) and have been reconstructed; confirm against pass output.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_if_dynamic_shape
func.func @outside_compiled_ops_inside_tf_if_dynamic_shape(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "if_predicate_channel_1"
  // CHECK-NEXT: "tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
  // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_args"
  // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
  // CHECK-NOT: "tf._XlaSendFromHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK: "tf.Yield"() : () -> ()
  // CHECK: _else_func_name = "test_else_name"
  // The same-line directive below previously lacked its colon and was ignored.
  // CHECK-SAME: _then_func_name = "test_then_name"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
  // CHECK: "tf._XlaHostComputeMlir"
  // CHECK-SAME: key = "if_predicate_channel_1"
  // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
  // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  // CHECK-NEXT: "tf.Yield"() : () -> ()
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %4 = "tf.B"() : () -> (tensor<?xi32>)
      %6 = "tf.G"() : () -> (tensor<i1>)
      "tf.IfRegion"(%6) ({
        "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<2xi32>) -> ()
        "tf.Yield"() : () -> ()
      }, {
        "tf.Yield"() : () -> ()
      }) { is_stateless = false, _then_func_name = "test_then_name", _else_func_name = "test_else_name"} : (tensor<i1>) -> ()
      %5 = "tf.E"() : () -> tensor<2xi32>
      tf_device.return %5 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Ensures that separate send/recvs are added for values that are used by ops inside of multiple IfRegions.
// The same cluster value (%3) feeds an outside compiled op in two different
// IfRegions; a distinct channel (0 and 1) is expected for each use.
// NOTE(review): the predicate receive expectation was missing (the
// "if_predicate_channel_2" key was being matched against a placeholder op
// line) and has been reconstructed; confirm against pass output.
// CHECK-LABEL: func @outside_compiled_ops_multiple_tf_if
func.func @outside_compiled_ops_multiple_tf_if(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
  // CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
  // CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
  // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "if_predicate_channel_2"
  // CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_0_args"
  // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]])
  // CHECK: %[[ARG_RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
  // CHECK-SAME: key = "host_compute_channel_1_args"
  // CHECK: "tf.F"(%[[ARG_RECV_OUTPUT_2]])
  // CHECK: "tf_device.cluster"
  // Bind the in-cluster "tf.A" result here; it was previously referenced
  // without being captured in this function.
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._XlaHostComputeMlir"
  // CHECK-SAME: key = "if_predicate_channel_2"
  // CHECK-NEXT: "tf.IfRegion"
  // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
  // CHECK-SAME: send_key = "host_compute_channel_0_args"
  // CHECK: "tf.IfRegion"
  // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: recv_key = "host_compute_channel_1_retvals"
  // CHECK-SAME: send_key = "host_compute_channel_1_args"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ({
      %3 = "tf.A"() : () -> (tensor<2xi32>)
      %6 = "tf.G"() : () -> (tensor<i1>)
      "tf.IfRegion"(%6) ({
        "tf.D"(%3) {_xla_outside_compilation = "auto"} : (tensor<2xi32>) -> ()
        "tf.Yield"() : () -> ()
      }, {
        "tf.Yield"() : () -> ()
      }) { is_stateless = false} : (tensor<i1>) -> ()
      "tf.IfRegion"(%6) ({
        "tf.F"(%3) {_xla_outside_compilation = "auto"} : (tensor<2xi32>) -> ()
        "tf.Yield"() : () -> ()
      }, {
        "tf.Yield"() : () -> ()
      }) { is_stateless = false} : (tensor<i1>) -> ()
      tf_device.return %3 : tensor<2xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
    tf_device.return %2 : tensor<2xi32>
  }
  func.return %1 : tensor<2xi32>
}
// Tests extraction of an outside compiled tf.IfRegion op where the entirety
// of tf.IfRegion op is outside compiled
// The `_xla_outside_compilation` attribute sits on the tf.IfRegion itself, not
// on the "tf.D" op nested inside it.
// CHECK-LABEL: func @outside_compiled_tf_if
func.func @outside_compiled_tf_if(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// Ops defined before the replicate; F_OUT is later used directly by "tf.D".
// CHECK: %[[A_OUT:[0-9]*]] = "tf.A"
// CHECK: %[[F_OUT:[0-9]*]] = "tf.F"
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: one receive delivers both tensor operands and the i1 predicate;
// the whole tf.IfRegion is expected inside the launch, with no send back.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: (tensor<3x!tf_type.string>, tensor<i64>) -> (tensor<2xi32>, tensor<2xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
// CHECK-NOT: "tf._XlaSendFromHostV2"
// Device side: a single host compute op forwards B, A and the predicate G.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// Input IR.
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
%7 = "tf.F"() : () -> tensor<2xi32>
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<2xi32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.IfRegion"(%6) ({
"tf.D"(%4, %3, %7) {} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> ()
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor<i1>) -> ()
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of an outside compiled tf.IfRegion op where the entirety
// of tf.IfRegion op is outside compiled and wrapped inside another
// tf.IfRegion op
// Only the inner tf.IfRegion carries `_xla_outside_compilation`; the outer one
// (predicate from "tf.G") is not marked.
// CHECK-LABEL: func @outside_compiled_tf_if_nested
func.func @outside_compiled_tf_if_nested(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK: %[[A_OUT:[0-9]*]] = "tf.A"
// CHECK: %[[F_OUT:[0-9]*]] = "tf.F"
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: the inner tf.IfRegion runs on a received predicate and "tf.H"
// reads one received tensor plus F_OUT from outside the replicate.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: (tensor<3x!tf_type.string>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: (tensor<3x!tf_type.string>, tensor<i64>) -> (tensor<2xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1)
// CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]])
// CHECK: "tf.Yield"() : () -> ()
// CHECK: "tf.Yield"() : () -> ()
// CHECK-NOT: "tf._XlaSendFromHostV2"
// Device side: the outer predicate is sent first, then inside the outer
// tf.IfRegion the outputs of "tf.D" and the in-region "tf.F" go to the host.
// Note that F_OUT is re-captured below and no longer refers to the top-level "tf.F".
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-SAME: (tensor<i1>) -> ()
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[D_OUT:[0-9]*]] = "tf.D"
// CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F"
// CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK: "tf.Yield"() : () -> ()
// CHECK: "tf.Yield"() : () -> ()
// Input IR.
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
%7 = "tf.F"() : () -> tensor<2xi32>
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<2xi32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.IfRegion"(%6) ({
%8 = "tf.D"(%4, %3, %7) {} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>)
%9 = "tf.F"(%4) {} : (tensor<2xi32>) -> (tensor<i1>)
"tf.IfRegion"(%9) ({
"tf.H"(%8, %7) : (tensor<2xi32>, tensor<2xi32>) -> ()
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor<i1>) -> ()
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) {is_stateless = false} : (tensor<i1>) -> ()
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of a single outside compiled cluster inside a tf.IfRegion
// op with return values.
// "tf.D" in the first region is outside compiled and its result is yielded
// from the tf.IfRegion, so a send from the host back to the device is expected.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_if_with_return_values
func.func @outside_compiled_ops_inside_tf_if_with_return_values(
%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: receive the two operands of "tf.D", run it, send its result back.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_retvals"
// CHECK-NEXT: "tf.Yield"() : () -> ()
// Device side: the predicate operand is matched through the captured
// G_OUTPUT rather than a literal SSA name, so the check does not depend on
// value numbering.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]])
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<2xi32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.IfRegion"(%6) ({
%7 = "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>)
"tf.Yield"(%7) : (tensor<2xi32>) -> ()
}, {
%8 = "tf.F"() : () -> (tensor<2xi32>)
"tf.Yield"(%8) : (tensor<2xi32>) -> ()
}) { is_stateless = false} : (tensor<i1>) -> (tensor<2xi32>)
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of a single outside compiled cluster inside a tf.IfRegion op without external inputs/outputs
// "tf.D" takes no operands and produces no results, so only the predicate is
// expected to be communicated to the host.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_if_without_input_outputs
func.func @outside_compiled_ops_inside_tf_if_without_input_outputs(
%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-SAME: key = "if_predicate_channel_0"
// CHECK: "tf.D"
// CHECK-NEXT: "tf.Yield"() : () -> ()
// Device side: the predicate operand is matched through the captured
// G_OUTPUT rather than a literal SSA name, so the check does not depend on
// value numbering.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_0"
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<2xi32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.IfRegion"(%6) ({
"tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) { is_stateless = false} : (tensor<i1>) -> ()
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of a single outside compiled cluster inside a nested
// tf.IfRegion op.
// The outside compiled op is "tf.D", whose operand comes from "tf.I".
// CHECK-LABEL: func @outside_compiled_ops_inside_nested_if
func.func @outside_compiled_ops_inside_nested_if(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: predicate channels 2 and 1 are checked in that order, then the
// receive feeding "tf.D"; nothing is sent back.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-SAME: key = "if_predicate_channel_2"
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]])
// CHECK-NOT: "tf._XlaSendFromHostV2"
// CHECK-NEXT: "tf.Yield"() : () -> ()
// Device side: each tf.IfRegion is preceded by a host compute op that
// forwards its predicate (channel 2 for G, channel 1 for H).
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_2"
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[H_OUTPUT]])
// CHECK-SAME: key = "if_predicate_channel_1"
// CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
// CHECK: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[I_OUTPUT]])
// CHECK-NEXT: "tf.Yield"() : () -> ()
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<2xi32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.IfRegion"(%6) ({
%7 = "tf.H"(%4) : (tensor<2xi32>) -> (tensor<i1>)
"tf.Yield"() : () -> ()
%8 = "tf.I"(%7) : (tensor<i1>) -> (tensor<2xi32>)
"tf.D"(%8) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>) -> ()
"tf.Yield"() : () -> ()
}) { is_stateless = false} : (tensor<i1>) -> ()
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) { is_stateless = false} : (tensor<i1>) -> ()
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op body.
// "tf.D" in the body region is the only op marked `_xla_outside_compilation`.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_body
func.func @outside_compiled_ops_inside_tf_while_body(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: a tf.WhileRegion whose condition is a value received on
// "while_condition_channel_0" and whose body receives two values, runs
// "tf.D" and sends its result back.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"
// Device side: the condition result from "tf.H" is sent to the host with
// tf.XlaSendToHost before being yielded; a host compute op follows "tf.C".
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<3xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%7) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%9 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<3xf32>) -> tensor<3xf32>
"tf.Yield"(%8, %9) : (tensor<i32>, tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<3xf32>) -> (tensor<i32>, tensor<3xf32>)
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond.
// "tf.I" in the condition region is the only op marked
// `_xla_outside_compilation`; the body region has no outside compiled ops.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond
func.func @outside_compiled_ops_inside_tf_while_cond(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: the condition region receives the operands of "tf.I", runs it,
// sends the result back, then receives the loop condition on
// "while_condition_channel_0" and yields it.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
// Device side: "tf.H" consumes the host compute result in the condition
// region; the body keeps "tf.C" and "tf.D" on the device.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[HOST_COMPUTE_OUTPUT:.+]] = "tf._XlaHostComputeMlir"
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[HOST_COMPUTE_OUTPUT]])
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: %[[D_OUTPUT:.+]] = "tf.D"
// CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]])
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<3xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<3xf32>) -> tensor<i32>
%8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%8) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%8 = "tf.D"(%arg1, %arg2) : (tensor<i32>, tensor<3xf32>) -> tensor<3xf32>
"tf.Yield"(%7, %8) : (tensor<i32>, tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<3xf32>) -> (tensor<i32>, tensor<3xf32>)
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond and body.
// "tf.I" in the condition region is marked "cluster1" and "tf.D" in the body
// region is marked "cluster2".
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond_body
func.func @outside_compiled_ops_inside_tf_while_cond_body(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: the condition region runs "tf.I" between a receive and a send,
// then receives the loop condition; the body region receives two values,
// runs "tf.D" and sends its result back.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[I_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"
// Device side.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: %[[HOST_COMPUTE_OUTPUT:.+]] = "tf._XlaHostComputeMlir"
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<3xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<3xf32>) -> tensor<i32>
%8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%8) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%8 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster2"} : (tensor<i32>, tensor<3xf32>) -> tensor<3xf32>
"tf.Yield"(%7, %8) : (tensor<i32>, tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<3xf32>) -> (tensor<i32>, tensor<3xf32>)
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of a single outside compiled cluster inside a tf.IfRegion op
// nested in a tf.WhileRegion.
// "tf.D" (no operands, one result) sits in the first region of a tf.IfRegion
// that is itself inside the body region of the tf.WhileRegion.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_if
func.func @outside_compiled_ops_inside_tf_while_if(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: the loop condition arrives on "while_condition_channel_1" and
// the result of "tf.D" is sent back to the device.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "while_condition_channel_1"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHostV2"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"
// Device side: after "tf.C" in the body, a host compute op forwards the
// predicate G, and the tf.IfRegion starts with the host compute op that
// stands in for "tf.D".
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]])
// CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]])
// CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<3xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%7) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%10 = "tf.IfRegion"(%6) ({
%9 = "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> tensor<3xf32>
"tf.Yield"(%9) : (tensor<3xf32>) -> ()
}, {
"tf.Yield"(%arg2) : (tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>) -> tensor<3xf32>
"tf.Yield"(%8, %10) : (tensor<i32>, tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<3xf32>) -> (tensor<i32>, tensor<3xf32>)
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of an outside compiled tf.IfRegion op where the entirety
// of tf.IfRegion op is outside compiled with a nested tf.WhileRegion op.
// The `_xla_outside_compilation` attribute sits on the tf.IfRegion; the
// tf.WhileRegion nested in its first region carries no attribute of its own.
// CHECK-LABEL: func @outside_compiled_tf_if_nested_while
func.func @outside_compiled_tf_if_nested_while(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK: %[[A_OUT:[0-9]*]] = "tf.A"
// CHECK: %[[F_OUT:[0-9]*]] = "tf.F"
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: one receive delivers both tensor operands and the predicate;
// the tf.IfRegion and its nested tf.WhileRegion are expected in the launch,
// and no send from the host may appear before the device cluster.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK-SAME: (tensor<3x!tf_type.string>, tensor<i64>) -> (tensor<2xi32>, tensor<2xi32>, tensor<i1>)
// CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2)
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
// CHECK-NEXT: %[[J_OUTPUT:[0-9]*]] = "tf.J"
// CHECK-NEXT: %[[K_OUTPUT:[0-9]*]] = "tf.K"
// CHECK-NEXT: tf.WhileRegion"(%[[J_OUTPUT]], %[[D_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[K_OUTPUT]])
// CHECK-NOT: "tf._XlaSendFromHostV2"
// Device side: a single host compute op forwards B, A and the predicate G.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
// CHECK-SAME: recv_key = "host_compute_channel_0_retvals"
// CHECK-SAME: send_key = "host_compute_channel_0_args"
// Input IR.
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
%7 = "tf.F"() : () -> tensor<2xi32>
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<2xi32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.IfRegion"(%6) ({
%8 = "tf.D"(%4, %3, %7) {} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> (tensor<3xf32>)
%9 = "tf.J"() : () -> (tensor<i32>)
%10 = "tf.K"() : () -> (tensor<i32>)
"tf.WhileRegion"(%9, %8) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%11 = "tf.I"(%arg1, %arg2) : (tensor<i32>, tensor<3xf32>) -> tensor<i32>
%12 = "tf.H"(%10) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%12) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%11 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%12 = "tf.D"(%arg1, %arg2) : (tensor<i32>, tensor<3xf32>) -> tensor<3xf32>
"tf.Yield"(%11, %12) : (tensor<i32>, tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<3xf32>) -> (tensor<i32>, tensor<3xf32>)
"tf.Yield"() : () -> ()
}, {
"tf.Yield"() : () -> ()
}) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor<i1>) -> ()
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of an outside compiled tf.WhileRegion where the entire
// tf.WhileRegion op is outside compiled with a nested tf.IfRegion.
// The `_xla_outside_compilation` attribute sits on the tf.WhileRegion; the
// tf.IfRegion nested in its body carries no attribute of its own.
// CHECK-LABEL: func @outside_compiled_ops_tf_while_nested_if
func.func @outside_compiled_ops_tf_while_nested_if(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: one receive of three values; results #1 and #2 are the loop
// operands and result #0 is the predicate of the nested tf.IfRegion.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[HOST_RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: "tf.WhileRegion"(%[[HOST_RECV_OUTPUT]]#1, %[[HOST_RECV_OUTPUT]]#2)
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: "tf.IfRegion"(%[[HOST_RECV_OUTPUT]]#0)
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
// CHECK-NEXT: "tf.Yield"
// Device side: a single host compute op forwards G, B and A.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUTPUT]], %[[B_OUTPUT]], %[[A_OUTPUT]])
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<3xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%7) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%10 = "tf.IfRegion"(%6) ({
%9 = "tf.D"(%8) : (tensor<i32>) -> tensor<3xf32>
"tf.Yield"(%9) : (tensor<3xf32>) -> ()
}, {
"tf.Yield"(%arg2) : (tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>) -> tensor<3xf32>
"tf.Yield"(%8, %10) : (tensor<i32>, tensor<3xf32>) -> ()
}) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor<i32>, tensor<3xf32>) -> (tensor<i32>, tensor<3xf32>)
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of an outside compiled cluster that contains ops wrapped
// inside multiple regions of nested tf.IfRegion and tf.WhileRegion.
// "tf.B" (top level of the cluster), "tf.C" (while body) and "tf.D" (if
// region inside the while body) are all marked with the same "cluster0".
// CHECK-LABEL: func @outside_compiled_ops_with_multiple_region_single_cluster
func.func @outside_compiled_ops_with_multiple_region_single_cluster(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// Host side: "tf.B" runs first and its result is sent to the device; the
// tf.WhileRegion takes no operands, and "tf.C" / "tf.D" use host-local values.
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[B_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.WhileRegion"()
// CHECK-NEXT: %[[WHILE_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.Yield"(%[[WHILE_COND]])
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]])
// CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[C_OUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: %[[IF_COND:.*]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-NEXT: "tf.IfRegion"(%[[IF_COND]])
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[C_OUT]])
// Device side: the result of "tf.B" is obtained through a host compute op
// with no operands and then feeds the tf.WhileRegion.
// CHECK: "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: %[[B_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"()
// CHECK-NEXT: %[[G_OUT:.*]] = "tf.G"
// CHECK-NEXT: "tf.WhileRegion"(%[[B_OUT_DEVICE]], %[[A_OUT]])
// CHECK: %[[H_OUT:.*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUT]])
// CHECK: "tf._XlaHostComputeMlir"(%[[G_OUT]])
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUT]])
// CHECK-NEXT: "tf._XlaHostComputeMlir"()
// Input IR.
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<3xf32>)
%4 = "tf.B"() {_xla_outside_compilation="cluster0"} : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%7) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<3xf32>):
%8 = "tf.C"(%4) {_xla_outside_compilation="cluster0"} : (tensor<i32>) -> tensor<i32>
%10 = "tf.IfRegion"(%6) ({
%9 = "tf.D"(%8) {_xla_outside_compilation="cluster0"} : (tensor<i32>) -> tensor<3xf32>
"tf.Yield"(%9) : (tensor<3xf32>) -> ()
}, {
"tf.Yield"(%arg2) : (tensor<3xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>) -> tensor<3xf32>
"tf.Yield"(%8, %10) : (tensor<i32>, tensor<3xf32>) -> ()
}) {is_stateless = false} : (tensor<i32>, tensor<3xf32>) -> (tensor<i32>, tensor<3xf32>)
%5 = "tf.E"() : () -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Verifies that ops in between outside compile ops and depending on results
// from the host are moved after the host compute op so that dominance is not
// violated. tf.C op in this case.
// In the input, tf.C sits between the two outside compiled ops (tf.B and
// tf.D) and consumes %1, the result of outside compiled tf.B.
// CHECK-LABEL: func @device_op_dominance
func.func @device_op_dominance() -> () {
// Host-side expectations: tf.B, a send to the device, a receive from the
// device, then tf.D.
// CHECK: tf.B
// CHECK: tf._XlaSendFromHost
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.D
// Device-side expectations: tf.A, the host compute op, and only then tf.C
// and tf.E, both of which read the host-produced value.
// CHECK: tf.A
// CHECK: tf._XlaHostComputeMlir
// CHECK: tf.C
// CHECK: tf.E
"tf_device.cluster"() ({
%0 = "tf.A"() : () -> (tensor<i32>)
%1 = "tf.B"() {_xla_outside_compilation = "cluster0"} : () -> (tensor<i32>)
// Device op that depends on the host result %1.
"tf.C"(%1) : (tensor<i32>) -> ()
// Outside compiled op that needs the device value %0 as well as %1.
"tf.D"(%1, %0) {_xla_outside_compilation = "cluster0"} : (tensor<i32>, tensor<i32>) -> ()
"tf.E"(%0, %1) : (tensor<i32>, tensor<i32>) -> ()
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
// Verifies that ops indirectly depending on results from the host are also
// moved after the host compute op. tf.E op in this case.
// In the input, tf.E does not read the host result %1 directly; it reads %2,
// which tf.C computes from %1.
// CHECK-LABEL: func @device_op_dominance_with_indirect_dependency
func.func @device_op_dominance_with_indirect_dependency() -> () {
// Host-side expectations: tf.B, a receive from the device, then tf.F.
// CHECK: tf.B
// CHECK: tf._XlaRecvAtHost
// CHECK: tf.F
// Device-side expectations: the listed ops must appear in this relative
// order in the output.
// CHECK: tf.A
// CHECK: tf.C
// CHECK: tf.D
// CHECK: tf.E
// CHECK: tf._XlaHostComputeMlir
// CHECK: tf.G
"tf_device.cluster"() ({
%0 = "tf.A"() : () -> (tensor<i32>)
%1 = "tf.B"() {_xla_outside_compilation = "cluster0"} : () -> (tensor<i32>)
// Direct dependency on the host result %1.
%2 = "tf.C"(%1) : (tensor<i32>) -> (tensor<i32>)
// No dependency on any host result.
%3 = "tf.D"() : () -> (tensor<i32>)
// Indirect dependency on %1 through %2.
"tf.E"(%2, %3) : (tensor<i32>, tensor<i32>) -> ()
"tf.F"(%1, %0) {_xla_outside_compilation = "cluster0"} : (tensor<i32>, tensor<i32>) -> ()
"tf.G"(%0, %1) : (tensor<i32>, tensor<i32>) -> ()
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
// Tests dynamically shaped input to outside compilation.
// The device op tf.A produces tensor<?xi32>, which is the only operand of the
// outside compiled op tf.B; tf.B has no results.
// CHECK-LABEL: func @dynamically_shaped_input
func.func @dynamically_shaped_input() -> () {
// Host-side expectations: the value received from the device feeds tf.B.
// CHECK: "tf_device.launch"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
// Device-side expectations: the tf.A result is passed to the host compute op,
// whose host_mlir_module attribute is expected to be the empty string.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: host_mlir_module = ""
"tf_device.cluster"() ({
%0 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
// Tests dynamically shaped output from outside compilation.
// The outside compiled op tf.B takes no operands and returns tensor<?xi32>,
// which the device op tf.C consumes.
// CHECK-LABEL: func @dynamically_shaped_output
func.func @dynamically_shaped_output() -> () {
// Host-side expectations: tf.B runs in the launch and its result is sent to
// the device.
// CHECK: "tf_device.launch"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]
// Device-side expectations: the host compute op takes no operands, and its
// host_mlir_module attribute is expected to hold a serialized module whose
// @host_func wraps tf.B and returns tensor<?xi32>; tf.C then reads the host
// compute result.
// CHECK: "tf_device.cluster"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"()
// CHECK-SAME: host_mlir_module = "module {\0A func.func @host_func() -> tensor<?xi32> {\0A %0 = \22tf.B\22() {_xla_outside_compilation = \22cluster1\22} : () -> tensor<?xi32> loc(#loc1)\0A return %0 : tensor<?xi32> loc(#loc1)\0A }
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
"tf_device.cluster"() ({
%0 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
"tf.C"(%0) : (tensor<?xi32>) -> ()
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
// Tests extraction of a single outside compiled cluster with arg input and
// single dynamic shape device->host input.
// The outside compiled op tf.B takes the function argument %arg0
// (tensor<2xi32>) and the dynamically shaped tf.A result (tensor<?xi32>).
// CHECK-LABEL: func @mixed_dynamic_input_single_outside_compilation
func.func @mixed_dynamic_input_single_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// Host-side expectations: a single receive keyed
// "host_compute_channel_0_args" yields two values, both of which feed tf.B.
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"(%[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK-SAME: key = "host_compute_channel_0_args"
// CHECK: "tf.B"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// Device-side expectations: the host compute op is given %arg0 and the tf.A
// result, in that order.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._XlaHostComputeMlir"(%arg0, %[[A_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<2xi32>, tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<2xi32>
tf_device.return %5 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests extraction of two chained outside compiled ops (tf.B feeding tf.C)
// where the device->host input and the host->device output are dynamically
// shaped (tensor<?xi32>).
// CHECK-LABEL: func @dynamic_input_chained_ops_outside_compilation
func.func @dynamic_input_chained_ops_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// Host-side expectations: one receive feeds tf.B, tf.B feeds tf.C directly,
// and only the tf.C result is sent back to the device.
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[B_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[C_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// Device-side expectations: a host compute op takes the tf.A result and its
// output feeds tf.D.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK: "tf.D"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%4) {_xla_outside_compilation = "cluster1"}: (tensor<?xi32>) -> tensor<?xi32>
%6 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<2xi32>
tf_device.return %6 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests that all inputs to a function are passed and all outputs are returned
// if any outputs of outside compilation cluster are dynamically shaped.
// In the input, outside compiled tf.C returns one dynamic and one static
// value; the static one (%5#1) is consumed only by the later outside compiled
// op tf.E, with the device op tf.D sitting between the two.
// CHECK-LABEL: func @dynamic_input_all_captured_outside_compilation
func.func @dynamic_input_all_captured_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// Host-side expectations: both tf.C results are sent to the device, and tf.E
// gets both of its operands from a second receive.
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"
// CHECK: %[[C_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[C_OUTPUT]]#0, %[[C_OUTPUT]]#1, %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[E_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// Device-side expectations: the first host compute op returns two values;
// the second one is handed the tf.D result together with the first op's
// second result.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[HOST_OUTPUT]]#0)
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[D_OUTPUT]], %[[HOST_OUTPUT]]#1)
// CHECK: "tf.F"(%[[HOST_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<?xi32>)
%5:2 = "tf.C"(%3) {_xla_outside_compilation = "auto"} : (tensor<?xi32>) -> (tensor<?xi32>, tensor<2xi32>)
%6 = "tf.D"(%5#0) : (tensor<?xi32>) -> (tensor<?xi32>)
%7 = "tf.E"(%6, %5#1) {_xla_outside_compilation = "auto"} : (tensor<?xi32>, tensor<2xi32>) -> (tensor<?xi32>)
%8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<2xi32>
tf_device.return %8 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests that all inputs to a function are passed and all outputs are returned
// for an outside compilation cluster with dynamically shaped inputs/outputs.
// This case checks that when an earlier outside compiled op returns only
// statically shaped outputs, a value it produces that is consumed by a later
// outside compiled op with some dynamically shaped inputs is sent back to the
// host.
// In the input, tf.C has only static operands/results, while tf.E takes the
// dynamically shaped tf.B result together with the static %5#1 from tf.C.
// CHECK-LABEL: func @dynamic_input_all_captured_mixed_shapes_outside_compilation
func.func @dynamic_input_all_captured_mixed_shapes_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// Host-side expectations: both tf.C results are sent to the device, and tf.E
// gets both of its operands from a second receive.
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"
// CHECK: %[[C_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[C_OUTPUT]]#0, %[[C_OUTPUT]]#1, %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[E_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// Device-side expectations: the second host compute op is handed the tf.B
// result together with the second result of the first host compute op.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[HOST_OUTPUT]]#0)
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[HOST_OUTPUT]]#1)
// CHECK: "tf.F"(%[[HOST_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
%5:2 = "tf.C"(%3) {_xla_outside_compilation = "auto"} : (tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
%6 = "tf.D"(%5#0) : (tensor<2xi32>) -> (tensor<2xi32>)
%7 = "tf.E"(%4, %5#1) {_xla_outside_compilation = "auto"} : (tensor<?xi32>, tensor<2xi32>) -> (tensor<?xi32>)
%8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<2xi32>
tf_device.return %8 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests the case when an outside compiled op with only statically shaped
// inputs/outputs is in between two other outside compiled ops with some
// dynamic shapes.
// In the input, tf.F (static in, static out) sits between tf.E (dynamic in
// and out) and tf.G (one static and one dynamic operand, dynamic result).
// CHECK-LABEL: func @static_shapes_sandwiched_outside_compilation
func.func @static_shapes_sandwiched_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// Host-side expectations: only the first tf.C result is sent to the device;
// tf.F reads the second tf.C result directly on the host, while tf.G gets
// both of its operands from a receive.
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"
// CHECK: %[[C_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[C_OUTPUT]]#0, %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHostV2"
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHostV2"(%[[E_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[F_OUTPUT:[0-9]*]] = "tf.F"(%[[C_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[F_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[RECV_OUTPUT2:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"(%[[RECV_OUTPUT2]]#0, %[[RECV_OUTPUT2]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[G_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// Device-side expectations: four host compute ops. The third takes no
// operands, and the fourth is handed the results of the third and second,
// in that order.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[HOST_OUTPUT]])
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]])
// CHECK: %[[HOST_OUTPUT3:[0-9]*]] = "tf._XlaHostComputeMlir"()
// CHECK: %[[HOST_OUTPUT4:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[HOST_OUTPUT3]], %[[HOST_OUTPUT2]])
// CHECK: "tf.H"(%[[HOST_OUTPUT4]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
%5:2 = "tf.C"(%3) {_xla_outside_compilation = "auto"} : (tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
%6 = "tf.D"(%5#0) : (tensor<2xi32>) -> (tensor<2xi32>)
%7 = "tf.E"(%4) {_xla_outside_compilation = "auto"} : (tensor<?xi32>) -> (tensor<?xi32>)
%8 = "tf.F"(%5#1) {_xla_outside_compilation = "auto"} : (tensor<2xi32>) -> (tensor<2xi32>)
%9 = "tf.G"(%8, %7) {_xla_outside_compilation = "auto"} : (tensor<2xi32>, tensor<?xi32>) -> (tensor<?xi32>)
%10 = "tf.H"(%9) : (tensor<?xi32>) -> tensor<2xi32>
tf_device.return %10 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// Tests the case when an outside compiled op with some dynamically shaped
// input but static output comes before an outside compiled op with some
// dynamically shaped input/output. This ensures that for the second outside
// compiled op, all of the inputs are explicitly provided.
// In the input, tf.C takes (%4 dynamic, %3 static) and returns a static
// tensor; tf.E takes (%6, %4, %3) and returns a dynamic tensor.
// CHECK-LABEL: func @static_output_dynamic_input_outside_compilation
func.func @static_output_dynamic_input_outside_compilation(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.A"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
// Host-side expectations: tf.C gets its two operands from one receive, and
// tf.E gets all three of its operands from a second receive.
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-DAG: %[[PROGRAM_OUTPUT:.+]] = "tf._TPUCompileMlirPlaceholderProgramKey"
// CHECK-DAG: %[[DEVICE_ORDINAL:.+]] = "tf._TPUDeviceOrdinalPlaceholder"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHostV2"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHostV2"(%[[C_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// CHECK: %[[RECV_OUTPUT2:[0-9]*]]:3 = "tf._XlaRecvAtHostV2"
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[RECV_OUTPUT2]]#0, %[[RECV_OUTPUT2]]#1, %[[RECV_OUTPUT2]]#2)
// CHECK: "tf._XlaSendFromHostV2"(%[[E_OUTPUT]], %[[PROGRAM_OUTPUT]], %[[DEVICE_ORDINAL]])
// Device-side expectations: the second host compute op is handed the tf.D
// result and, again, the tf.B and tf.A results that were already passed to
// the first host compute op.
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[HOST_OUTPUT]])
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[D_OUTPUT]], %[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: "tf.F"(%[[HOST_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<2xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ({
%3 = "tf.A"() : () -> (tensor<2xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
%5 = "tf.C"(%4, %3) {_xla_outside_compilation = "auto"} : (tensor<?xi32>, tensor<2xi32>) -> (tensor<2xi32>)
%6 = "tf.D"(%5) : (tensor<2xi32>) -> (tensor<2xi32>)
%7 = "tf.E"(%6, %4, %3) {_xla_outside_compilation = "auto"} : (tensor<2xi32>, tensor<?xi32>, tensor<2xi32>) -> (tensor<?xi32>)
%8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<2xi32>
tf_device.return %8 : tensor<2xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<2xi32>
tf_device.return %2 : tensor<2xi32>
func.return %1 : tensor<2xi32>
// -----
// Tests that model parallelism does not affect outside compilation.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @outside_compilation_model_parallelism
func.func @outside_compilation_model_parallelism() -> () {
// The input cluster below sets num_cores_per_replica = 2. The outside
// compiled op tf.B is expected to end up alone in a launch placed on the
// CPU:0 device with its _xla_outside_compilation attribute gone, while the
// cluster keeps tf.A and its num_cores_per_replica = 2 attribute.
// CHECK: "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK: device_assignment = [], num_cores_per_replica = 2 : i64, topology = ""
"tf_device.cluster"() ({
"tf.A"() : () -> ()
// The only outside compiled op; it has no operands and no results.
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
}) {num_cores_per_replica = 2, topology = "", device_assignment = []} : () -> ()