-tf-device-cluster-outlining: Outlines regions of tf_device.cluster operations

This pass outlines the body of a tf_device.cluster into a function and replaces the tf_device.cluster op with an equivalent tf_device.cluster_func op. Implicit operands will be captured and materialized as explicit arguments to the newly created functions and associated tf_device.cluster_func ops.

For example, the following:

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}

will be transformed into:

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor<i32>) -> tensor<i32>
  return %cluster : tensor<i32>
}

func @_func(%arg0: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}

-tf-device-constant-sinking: Sinks constants implicitly captured in a tf_device.cluster region.

This pass sinks implicitly captured constants (tf.Const ops) used by a tf_device.cluster region into that region. Performing this prior to outlining will reduce the number of arguments of the outlined function.

For example, the following:

func @cluster() -> tensor<i32> {
  %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}

will be transformed into:

func @cluster() -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}

-tf-executor-graph-pruning: Prunes unreachable ops in a tf_executor.graph

This pass removes ops from a tf_executor.graph that are not transitively, via data or control dependencies, connected to the associated tf_executor.fetch op. The order of ops will be preserved. Functions named main with no tf.entry_function attribute will not be pruned, as such graphs/functions may have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are not provided at certain stages of IR transformation (e.g. pre-placement).

For example, the following:

func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %graph = tf_executor.graph {
    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
    %unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
    %unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
  }
  return %graph : tensor<i32>
}

will be transformed into:

func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %graph = tf_executor.graph {
    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
  }
  return %graph : tensor<i32>
}

-tf-executor-to-functional-conversion: Lifts tf_executor.island inner ops from a tf_executor.graph

This pass converts tf_executor.graphs consisting of only tf_executor.islands and a tf_executor.fetch into a sea of nodes consisting of TensorFlow Dialect ops by lifting such ops out of a tf_executor.graph's tf_executor.islands. If V1 control flow ops are present in a tf_executor.graph, an error will be returned.

For example, the following:

func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %graph_results:2 = tf_executor.graph {
    %island_0_result, %island_0_control = tf_executor.island {
      %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
      tf_executor.yield %identity : tensor<i32>
    }
    %island_1_result, %island_1_control = tf_executor.island {
      %identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
      tf_executor.yield %identity_n#0 : tensor<i32>
    }
    tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
  }
  return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
}

will be transformed into:

func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  %identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
}

-tf-functional-control-flow-to-regions: Transforms functional control flow operations to their region-based counterparts

This pass transforms functional control flow operations in the TensorFlow dialect to their region-based counterparts, i.e., tf.If is transformed to tf.IfRegion and tf.While is transformed to tf.WhileRegion.

For example, this functional operation

  %0 = "tf.If"(%arg0, %arg1) {
    then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
  } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>

will be transformed into this region-based operation

    %0 = "tf.IfRegion"(%arg0) ( {
      %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    },  {
      %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>

-tf-mark-ops-for-outside-compilation: Marks ops in device cluster for outside compilation if they are unsupported on device.

This pass marks unsupported ops in a device cluster with the _xla_outside_compilation attribute so the operations will run on the host instead of the device. Unsupported ops are ops that cannot be code generated to run on the device for the cluster, including:

  1. String operations on TPUs.
  2. Operations that don't have a kernel defined for the device.

This pass is conservative in that it will mark all ops for outside compilation that can not be compiled for the device. Exceptions for this are added for ops that will be rewritten or decomposed before compiling on device.

For example, tf_device.cluster op with an unsupported op, tf.UnsupportedOp:

func @unsupported_op() -> tensor<i32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.UnsupportedOp"() : () -> tensor<i32>
    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
    tf_device.return %2 : tensor<i32>
  }) {allow_soft_placement = true, num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
  return %0 : tensor<i32>
}

will mark tf.UnsupportedOp with _xla_outside_compilation attribute:

func @unsupported_op() -> tensor<i32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
    tf_device.return %2 : tensor<i32>
  }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
  return %0 : tensor<i32>
}

-tf-region-control-flow-to-functional: Transforms region-based control flow operations to their functional counterparts

This pass transforms region-based control flow operations in the TensorFlow dialect to their functional counterparts, i.e., tf.IfRegion is transformed to tf.If and tf.WhileRegion is transformed to tf.While.

For example, this region-based operation

    %0 = "tf.IfRegion"(%arg0) ( {
      %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    },  {
      %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>

will be transformed into this functional operation

  %0 = "tf.If"(%arg0, %arg1) {
    then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
  } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>

-tf-shape-inference: Simple Shape Inference on TensorFlow Dialect

Options

-max-iterations : Maximum shape inference iterations

-tf-tpu-cluster-formation: Forms clusters from operations assigned to the same TPU computation

TPU computations from the frontend are composed of a tf.TPUReplicateMetadata op, a subgraph of ops (TensorFlow Dialect) each with a matching _tpu_replicate attribute relative to the associated tf.TPUReplicateMetadata op, and optionally tf.TPUReplicatedInput and tf.TPUReplicatedOutput ops feeding in inputs and outputs to and from a replicated TPU computation. The number of times a TPU computation is replicated is defined in the tf.TPUReplicateMetadata op (num_replicas attribute) and operand and result sizes of tf.TPUReplicatedInput and tf.TPUReplicatedOutput respectively must match, excluding packed tensors. It is also assumed ops of the same TPU computation do not have ops outside of the TPU computation that are both inputs and outputs to the same TPU computation.

This pass takes the TPU computation subgraph, moves it into a tf_device.cluster, and copies over attributes from the associated tf.TPUReplicateMetadata op to the newly created tf_device.cluster. If the computation is replicated (num_replicas > 1), the num_replicas attribute is not copied over but instead the tf_device.cluster is further wrapped with a tf_device.replicate, and associated tf.TPUReplicatedInput and tf.TPUReplicatedOutput ops are replaced as the tf_device.replicate operands and results. Otherwise, the single operands and results of the associated tf.TPUReplicatedInput and tf.TPUReplicatedOutput ops are simply forwarded to the tf_device.cluster.

For example, the following non replicated computation:

func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
  // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
  // with topology `<topology>`.
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_replicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
  %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
  %replicated_output = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> tensor<i32>
  return %replicated_output : tensor<i32>
}

will be transformed into:

func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}

The following replicated computation:

func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_replicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
  %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
  %replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
}

will be transformed into:

func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input: tensor<i32>) {n = 2 : i32} {
    %cluster = "tf_device.cluster"() ( {
      %identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
      tf_device.return %identity : tensor<i32>
    }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
    tf_device.return %cluster : tensor<i32>
  }
  return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
}

-tf-tpu-extract-outside-compilation: Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.

This pass extracts a CPU computation cluster with _xla_outside_compilation annotation, which denotes ops that should be run on CPU/host, from a TPU cluster. Each outside compilation cluster is moved to a tf_device.parallel_execute region. The TPU cluster is also moved to a tf_device.parallel_execute region. Communication ops between device and host are added to pass inputs/outputs to/from the outside compiled region.

For example, the following tf_device.cluster with an op marked with the _xla_outside_compilation attribute:

func @outside_compilation() -> tensor<f32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
    %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
    %3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
    tf_device.return %3 : tensor<f32>
  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<f32>
  return %0 : tensor<f32>
}

will become a tf_device.parallel_execute op with a CPU/host region and a tf_device.cluster with communication ops to send data to/from device/host:

func @outside_compilation() -> tensor<f32> {
  %0 = "tf_device.parallel_execute"() ( {
    "tf_device.launch"() ( {
      %1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
      %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
      %3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
      "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
      tf_device.return
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
    tf_device.return
  },  {
    %1 = "tf_device.cluster"() ( {
      %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
      %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
      %4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      tf_device.return %4 : tensor<f32>
    }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
    tf_device.return %1 : tensor<f32>
  }) : () -> tensor<f32>
  return %0 : tensor<f32>
}

-tf-tpu-resource-partition: Partitions unpartitioned resource read/write to partitioned resource variables.

This pass creates individual resource reads/writes from the unpartitioned resource variable (from tf.TPUPartitionedInput) to individual partitioned resource variables (tf.TPUPartitionedInput operands). As resource op decomposition/lifting occurs with the unpartitioned resource variables, transforming the IR in such a manner will allow for subsequent passes to operate on individual resource variable handles per core/device.

For example, the following:

func @cluster(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
  %partitioned_variable = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<i32>>>, tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
  %read = "tf.ReadVariableOp"(%partitioned_variable) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %computation = "tf_device.cluster_func"(%read) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%partitioned_variable, %computation) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0: tensor<i32>
}

will be transformed into:

func @cluster(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
  %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
  %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
  %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
  "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0: tensor<i32>
}

-tf-tpu-resource-read-for-write: Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads

This pass materializes tf.ReadVariableOp inputs to an outlined TPU computation for resource variables where only writes are present so later in the pipeline such resource variables can be fused with generated tf.TPUExecute ops, which only support resource variable read or read + write. For all TPU computations, resource variables are required to be initialized prior to execution. Write only resource variable uses can be generated currently via packed tensor uses.

For example, the following:

func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
  %0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @cluster(%arg0: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}

will be transformed into:

func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
  %resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
  %0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @cluster(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}

-tf-tpu-rewrite: Rewrites a tf_device.cluster_func on TPUs into TPU runtime operations.

This pass rewrites a tf_device.cluster_func operation into a sequence of tf._TPUCompileMlir and tf.TPUExecute operations. tf._TPUCompileMlir contains an MLIR module that is functionally equivalent to the function referenced by tf_device.cluster_func. This allows the module to be JIT-compiled and executed on TPU. If it is not possible to rewrite the operation or device assignment fails, a failure will be returned.

Note, many parameters to the tf_device.cluster_func are omitted in this and following examples. For example, a non replicated tf_device.cluster_func:

func @tf_tpu_rewrite(%arg0: tensor<i8>) {
  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
  return
}

will be rewritten as:

func @tf_tpu_rewrite(%arg0: tensor<i8>) {
  %0:2 = "tf_device.launch"() ( {
    %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
    tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
  "tf_device.launch"() ( {
    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
    tf_device.return
  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
  %1 = "tf_device.launch"() ( {
    %2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
    tf_device.return %2 : tensor<i8>
  }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
  return
}

A replicated tf_device.cluster_func:

func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
  %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
    %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
    tf_device.return %1 : tensor<i8>
  }
  return
}

will be rewritten as:

func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
  %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
    %1:2 = "tf_device.launch"() ( {
      %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
      tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
    "tf_device.launch"() ( {
      "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
      tf_device.return
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
    %2 = "tf_device.launch"() ( {
      %3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
      tf_device.return %3 : tensor<i8>
    }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
    tf_device.return %2 : tensor<i8>
  }
  return
}

A non replicated tf_device.cluster_func with model parallelism:

func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
  return %0 : tensor<8xi32>
}

will be rewritten as:

func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
  %0:3 = "tf_device.launch"() ( {
    %compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
    tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>
  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
  "tf_device.launch"() ( {
    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
    tf_device.return
  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
  %1 = "tf_device.parallel_execute"() ( {
    %2 = "tf_device.launch"() ( {
      %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32>
      tf_device.return %3 : tensor<8xi32>
    }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
    tf_device.return %2 : tensor<8xi32>
  },  {
    "tf_device.launch"() ( {
      "tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> ()
      tf_device.return
    }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
    tf_device.return
  }) : () -> tensor<8xi32>
  return %1 : tensor<8xi32>
}