### `-tf-device-cluster-outlining`: Outlines regions of tf_device.cluster operations

This pass outlines the body of a `tf_device.cluster` into a function and replaces the `tf_device.cluster` op with an equivalent `tf_device.cluster_func` op. Implicit operands will be captured and materialized as explicit arguments to the newly created functions and associated `tf_device.cluster_func` ops.
For example, the following:
```mlir
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```
will be transformed into:
```mlir
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor<i32>) -> tensor<i32>
  return %cluster : tensor<i32>
}

func @_func(%arg0: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}
```
### `-tf-device-constant-sinking`: Sinks constants implicitly captured in a tf_device.cluster region

This pass sinks implicitly captured constants (`tf.Const` ops) used by and into a `tf_device.cluster` region. Performing this prior to outlining will reduce the number of arguments of the outlined function.
For example, the following:
```mlir
func @cluster() -> tensor<i32> {
  %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```
will be transformed into:
```mlir
func @cluster() -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```
### `-tf-executor-graph-pruning`: Prunes unreachable ops in a tf_executor.graph

This pass removes ops from a `tf_executor.graph` that are not transitively connected, via data or control dependencies, to the associated `tf_executor.fetch` op. The order of ops will be preserved. Functions named `main` with no `tf.entry_function` attribute will not be pruned, as such graphs/functions may have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are not provided at certain stages of IR transformation (e.g. pre-placement); a sketch after the example below illustrates this exemption.
For example, the following:
```mlir
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %graph = tf_executor.graph {
    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
    %unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
    %unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
  }
  return %graph : tensor<i32>
}
```
will be transformed into:
```mlir
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %graph = tf_executor.graph {
    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
  }
  return %graph : tensor<i32>
}
```
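To illustrate the `main` exemption described above, here is a minimal hypothetical sketch (not taken from the pass documentation): because the function is named `main` and carries no `tf.entry_function` attribute, the pass leaves its graph untouched, even though the island below is never fetched:

```mlir
// Hypothetical sketch: @main has no tf.entry_function attribute, so this
// graph is skipped by pruning and the unfetched island is kept.
func @main() {
  tf_executor.graph {
    %unfetched:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    tf_executor.fetch
  }
  return
}
```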
### `-tf-executor-to-functional-conversion`: Lifts tf_executor.island inner ops from a tf_executor.graph

This pass converts `tf_executor.graph`s consisting of only `tf_executor.island`s and a `tf_executor.fetch` into a sea of nodes consisting of TensorFlow Dialect ops by lifting such ops out of a `tf_executor.graph`'s `tf_executor.island`s. If V1 control flow ops are present in a `tf_executor.graph`, an error will be returned.
For example, the following:
```mlir
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %graph_results:2 = tf_executor.graph {
    %island_0_result, %island_0_control = tf_executor.island {
      %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
      tf_executor.yield %identity : tensor<i32>
    }
    %island_1_result, %island_1_control = tf_executor.island {
      %identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
      tf_executor.yield %identity_n#0 : tensor<i32>
    }
    tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
  }
  return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
}
```
will be transformed into:
```mlir
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  %identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
}
```
### `-tf-functional-control-flow-to-regions`: Transforms functional control flow operations to their region-based counterparts

This pass transforms functional control flow operations in the TensorFlow dialect to their region-based counterparts, i.e., `tf.If` is transformed to `tf.IfRegion` and `tf.While` is transformed to `tf.WhileRegion`.
For example, this functional operation
%0 = "tf.If"(%arg0, %arg1) { then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
will be transformed into this region-based operation
%0 = "tf.IfRegion"(%arg0) ( { %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> "tf.Yield"(%1) : (tensor<*xf32>) -> () }, { %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> "tf.Yield"(%1) : (tensor<*xf32>) -> () }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
### `-tf-mark-ops-for-outside-compilation`: Marks ops in device cluster for outside compilation if they are unsupported on device.

This pass marks unsupported ops in a device cluster with the `_xla_outside_compilation` attribute so that those operations will run on the host instead of the device. Unsupported ops are ops that cannot be code generated to run on the device for the cluster, including:

1. String operations on TPUs.
2. Operations that don't support auto clustering.

This pass is conservative in that it will mark all ops for outside compilation that cannot be compiled for the device. Exceptions are made for ops that will be rewritten or decomposed before compiling on the device.
For example, a `tf_device.cluster` op with an unsupported op, `tf.UnsupportedOp`:
```mlir
func @unsupported_op() -> tensor<i32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.UnsupportedOp"() : () -> tensor<i32>
    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
    tf_device.return %2 : tensor<i32>
  }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
  return %0 : tensor<i32>
}
```
will mark `tf.UnsupportedOp` with the `_xla_outside_compilation` attribute:
```mlir
func @unsupported_op() -> tensor<i32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
    tf_device.return %2 : tensor<i32>
  }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
  return %0 : tensor<i32>
}
```
### `-tf-region-control-flow-to-functional`: Transforms region-based control flow operations to their functional counterparts

This pass transforms region-based control flow operations in the TensorFlow dialect to their functional counterparts, i.e., `tf.IfRegion` is transformed to `tf.If` and `tf.WhileRegion` is transformed to `tf.While`.
For example, this region-based operation
%0 = "tf.IfRegion"(%arg0) ( { %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> "tf.Yield"(%1) : (tensor<*xf32>) -> () }, { %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> "tf.Yield"(%1) : (tensor<*xf32>) -> () }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
will be transformed into this functional operation
%0 = "tf.If"(%arg0, %arg1) { then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
### `-tf-shape-inference`: Simple Shape Inference on TensorFlow Dialect

#### Options

```
-max-iterations : Maximum shape inference iterations
```
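For illustration, a minimal hypothetical example (not taken from the pass documentation; the function name `@propagate` is made up) of the kind of refinement this pass performs, propagating a known argument shape through an op to refine unranked result types:

```mlir
// Before shape inference: the result of tf.Identity is unranked even though
// the argument shape is fully known.
func @propagate(%arg0: tensor<2xf32>) -> tensor<*xf32> {
  %0 = "tf.Identity"(%arg0) : (tensor<2xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
```

might be refined to:

```mlir
// After shape inference: the known shape is propagated through tf.Identity
// and the function result type is refined to match.
func @propagate(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "tf.Identity"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}
```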
### `-tf-tpu-cluster-formation`: Forms clusters from operations assigned to the same TPU computation

TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata` op, a subgraph of ops (TensorFlow Dialect) each with a matching `_tpu_replicate` attribute relative to the associated `tf.TPUReplicateMetadata` op, and optionally `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops feeding inputs to and outputs from a replicated TPU computation. The number of times a TPU computation is replicated is defined in the `tf.TPUReplicateMetadata` op (`num_replicas` attribute), and the operand and result counts of `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput`, respectively, must match it, excluding packed tensors. It is also assumed that no op outside of a TPU computation is both an input and an output to that same TPU computation.
This pass takes the TPU computation subgraph, moves it into a `tf_device.cluster`, and copies over attributes from the associated `tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. If the computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is not copied over but instead the `tf_device.cluster` is further wrapped with a `tf_device.replicate`, and the associated `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are replaced by the `tf_device.replicate` operands and results. Otherwise, the single operands and results of the associated `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded to the `tf_device.cluster`.
For example, the following non-replicated computation:
```mlir
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
  // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
  // with topology `<topology>`.
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_replicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
  %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
  %replicated_output = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> tensor<i32>
  return %replicated_output : tensor<i32>
}
```
will be transformed into:
```mlir
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```
The following replicated computation:
```mlir
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_replicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
  %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
  %replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
}
```
will be transformed into:
```mlir
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} {
    %cluster = "tf_device.cluster"() ( {
      %identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
      tf_device.return %identity : tensor<i32>
    }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
    tf_device.return %cluster : tensor<i32>
  }
  return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
}
```
### `-tf-tpu-extract-outside-compilation`: Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.

This pass extracts a CPU computation cluster with the `_xla_outside_compilation` annotation, which denotes ops that should be run on the CPU/host, from a TPU cluster. Each outside compilation cluster is moved to a `tf_device.parallel_execute` region. The TPU cluster is also moved to a `tf_device.parallel_execute` region. Communication ops between device and host are added to pass inputs/outputs to/from the outside compiled region.
For example, the following `tf_device.cluster` with an op marked with the `_xla_outside_compilation` attribute:
```mlir
func @outside_compilation() -> tensor<f32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
    %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
    %3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
    tf_device.return %3 : tensor<f32>
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
  return %0 : tensor<f32>
}
```
will become a `tf_device.parallel_execute` op with a CPU/host region and a `tf_device.cluster` with communication ops to send data to and from the device and host:
```mlir
func @outside_compilation() -> tensor<f32> {
  %0 = "tf_device.parallel_execute"() ( {
    "tf_device.launch"() ( {
      %1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
      %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
      %3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
      "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
      tf_device.return
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
    tf_device.return
  }, {
    %1 = "tf_device.cluster"() ( {
      %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
      %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
      %4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      tf_device.return %4 : tensor<f32>
    }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
    tf_device.return %1 : tensor<f32>
  }) : () -> tensor<f32>
  return %0 : tensor<f32>
}
```
### `-tf-tpu-resource-partition`: Partitions unpartitioned resource read/write to partitioned resource variables.

This pass creates individual resource reads/writes from the unpartitioned resource variable (from `tf.TPUPartitionedInput`) to the individual partitioned resource variables (`tf.TPUPartitionedInput` operands). As resource op decomposition/lifting occurs with the unpartitioned resource variables, transforming the IR in such a manner will allow subsequent passes to operate on individual resource variable handles per core/device.
For example, the following:
```mlir
func @cluster(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
  %partitioned_variable = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<i32>>>, tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
  %read = "tf.ReadVariableOp"(%partitioned_variable) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %computation = "tf_device.cluster_func"(%read) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%partitioned_variable, %computation) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0: tensor<i32>
}
```
will be transformed into:
```mlir
func @cluster(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
  %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
  %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
  "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0: tensor<i32>
}
```
### `-tf-tpu-resource-read-for-write`: Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads

This pass materializes `tf.ReadVariableOp` inputs to an outlined TPU computation for resource variables where only writes are present, so that later in the pipeline such resource variables can be fused with generated `tf.TPUExecute` ops, which only support resource variable reads or reads + writes. For all TPU computations, resource variables are required to be initialized prior to execution. Write-only resource variable uses can currently be generated via packed tensor uses.
For example, the following:
```mlir
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
  %0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @cluster(%arg0: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}
```
will be transformed into:
```mlir
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
  %resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
  %0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @cluster(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}
```
### `-tf-tpu-rewrite`: Rewrites a tf_device.cluster_func on TPUs into TPU runtime operations.

This pass rewrites a `tf_device.cluster_func` operation into a sequence of `tf._TPUCompileMlir` and `tf.TPUExecute` operations. `tf._TPUCompileMlir` contains a MLIR module that is functionally equivalent to the function referenced by `tf_device.cluster_func`. This allows the module to be jit-compiled and executed on a TPU. If it is not possible to rewrite the operation or device assignment fails, a failure will be returned.

Note that many parameters to the `tf_device.cluster_func` are omitted in this and the following examples.

For example, a non-replicated `tf_device.cluster_func`:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
  return
}
```
will be rewritten as:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
  %0:2 = "tf_device.launch"() ( {
    %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
    tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
  "tf_device.launch"() ( {
    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
    tf_device.return
  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
  %1 = "tf_device.launch"() ( {
    %2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
    tf_device.return %2 : tensor<i8>
  }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
  return
}
```
A replicated `tf_device.cluster_func`:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
  %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
    %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
    tf_device.return %1 : tensor<i8>
  }
  return
}
```
will be rewritten as:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
  %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
    %1:2 = "tf_device.launch"() ( {
      %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
      tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
    "tf_device.launch"() ( {
      "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
      tf_device.return
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
    %2 = "tf_device.launch"() ( {
      %3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
      tf_device.return %3 : tensor<i8>
    }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
    tf_device.return %2 : tensor<i8>
  }
  return
}
```

A non-replicated `tf_device.cluster_func` with model parallelism:

```mlir
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
  return %0 : tensor<8xi32>
}
```
will be rewritten as:
```mlir
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
  %0:3 = "tf_device.launch"() ( {
    %compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
    tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>
  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
  "tf_device.launch"() ( {
    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
    tf_device.return
  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
  %1 = "tf_device.parallel_execute"() ( {
    %2 = "tf_device.launch"() ( {
      %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32>
      tf_device.return %3 : tensor<8xi32>
    }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
    tf_device.return %2 : tensor<8xi32>
  }, {
    "tf_device.launch"() ( {
      "tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> ()
      tf_device.return
    }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
    tf_device.return
  }) : () -> tensor<8xi32>
  return %1 : tensor<8xi32>
}
```