Op documentation update.
update of g3doc/includes/tf_passes.md
PiperOrigin-RevId: 349495105
Change-Id: I399cf816b010dea67723c4ef72f76f1fcea13053
diff --git a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
index df5b1bf..a9485a1 100644
--- a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
+++ b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
@@ -430,3 +430,113 @@
return %identity : tensor<i32>
}
```
+### `-tf-tpu-rewrite`: Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime operations.
+This pass rewrites a `tf_device.cluster_func` operation into a sequence of `tf._TPUCompileMlir`
+and `tf.TPUExecute` operations. `tf._TPUCompileMlir` contains an MLIR module that is
+functionally equivalent to the function referenced by `tf_device.cluster_func`.
+This allows the module to be jit-compiled and executed on a TPU.
+If it is not possible to rewrite the operation, or if device assignment fails,
+the pass returns a failure.
+
+Note, many parameters to the `tf_device.cluster_func` are omitted in this
+and the following examples.
+For example, a non-replicated `tf_device.cluster_func`:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>) {
+ %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
+ return
+}
+```
+
+will be rewritten as:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>) {
+ %0:2 = "tf_device.launch"() ( {
+ %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+ tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
+ }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+ "tf_device.launch"() ( {
+ "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
+ tf_device.return
+ }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+ %1 = "tf_device.launch"() ( {
+ %2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
+ tf_device.return %2 : tensor<i8>
+ }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
+ return
+}
+```
+
+A replicated `tf_device.cluster_func`:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
+ %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
+ %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
+ tf_device.return %1 : tensor<i8>
+ }
+ return
+}
+```
+
+will be rewritten as:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
+ %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
+ %1:2 = "tf_device.launch"() ( {
+ %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+ tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
+ }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+ "tf_device.launch"() ( {
+ "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
+ tf_device.return
+ }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+ %2 = "tf_device.launch"() ( {
+ %3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
+ tf_device.return %3 : tensor<i8>
+ }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
+ tf_device.return %2 : tensor<i8>
+ }
+ return
+}
+```
+A non-replicated `tf_device.cluster_func` with model parallelism:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
+ %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
+ return %0 : tensor<8xi32>
+}
+```
+
+will be rewritten as:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
+ %0:3 = "tf_device.launch"() ( {
+ %compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
+ tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>
+ }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
+ "tf_device.launch"() ( {
+ "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
+ tf_device.return
+ }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
+ %1 = "tf_device.parallel_execute"() ( {
+ %2 = "tf_device.launch"() ( {
+ %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32>
+ tf_device.return %3 : tensor<8xi32>
+ }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
+ tf_device.return %2 : tensor<8xi32>
+ }, {
+ "tf_device.launch"() ( {
+ "tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> ()
+ tf_device.return
+ }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
+ tf_device.return
+ }) : () -> tensor<8xi32>
+ return %1 : tensor<8xi32>
+}
+```