Op documentation update.
	update of g3doc/includes/tf_passes.md

PiperOrigin-RevId: 349495105
Change-Id: I399cf816b010dea67723c4ef72f76f1fcea13053
diff --git a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
index df5b1bf..a9485a1 100644
--- a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
+++ b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md
@@ -430,3 +430,113 @@
   return %identity : tensor<i32>
 }
 ```
+### `-tf-tpu-rewrite`: Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime operations.
+This pass rewrites a `tf_device.cluster_func` operation into a sequence of `tf._TPUCompileMlir`
+and `tf.TPUExecute` operations. `tf._TPUCompileMlir` contains an MLIR module that is
+functionally equivalent to the function referenced by `tf_device.cluster_func`, which allows
+the module to be JIT-compiled and executed on a TPU. If the operation cannot be rewritten,
+or if device assignment fails, the pass returns a failure.
+
+Note that many parameters of the `tf_device.cluster_func` are omitted in this
+and the following examples.
+For example, a non-replicated `tf_device.cluster_func`:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>) {
+  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
+  return
+}
+```
+
+will be rewritten as:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>) {
+  %0:2 = "tf_device.launch"() ( {
+    %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+    tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
+  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+  "tf_device.launch"() ( {
+    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
+    tf_device.return
+  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+  %1 = "tf_device.launch"() ( {
+    %2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
+    tf_device.return %2 : tensor<i8>
+  }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
+  return
+}
+```
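+
+In the rewritten form, the compilation (`tf._TPUCompileMlir`) and the compilation check
+(`tf.TPUCompileSucceededAssert`) are wrapped in `tf_device.launch` operations placed on the
+host CPU device, while the `tf.TPUExecute` launch is placed on the TPU device. The pass can
+be exercised in isolation, e.g. with `tf-opt -tf-tpu-rewrite`.
+
+For illustration only, the `@func` referenced above could be as simple as the following
+sketch; the actual body is whatever computation was outlined into the cluster, and
+`tf.Identity` here is just a placeholder:
+
+```mlir
+func @func(%arg0: tensor<i8>) -> tensor<i8> {
+  // Placeholder body; any TF computation on %arg0 could appear here.
+  %0 = "tf.Identity"(%arg0) : (tensor<i8>) -> tensor<i8>
+  return %0 : tensor<i8>
+}
+```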
+
+A replicated `tf_device.cluster_func`:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
+  %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
+    %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
+    tf_device.return %1 : tensor<i8>
+  }
+  return
+}
+```
+
+will be rewritten as:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
+  %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
+    %1:2 = "tf_device.launch"() ( {
+      %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+      tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
+    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
+    "tf_device.launch"() ( {
+      "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
+      tf_device.return
+    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
+    %2 = "tf_device.launch"() ( {
+      %3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
+      tf_device.return %3 : tensor<i8>
+    }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
+    tf_device.return %2 : tensor<i8>
+  }
+  return
+}
+```
+
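+Here `TPU_REPLICATED_CORE_0` and `TPU_REPLICATED_HOST` in the `devices` map of
+`tf_device.replicate` are virtual device aliases: each replica is assigned the corresponding
+concrete device from the alias's list, so replica 0 executes on
+`/job:worker/replica:0/task:0/device:TPU:0` and replica 1 on
+`/job:worker/replica:0/task:0/device:TPU:1`.
+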
+A non-replicated `tf_device.cluster_func` with model parallelism:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
+  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+```
+
+will be rewritten as:
+
+```mlir
+func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
+  %0:3 = "tf_device.launch"() ( {
+    %compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
+    tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>
+  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
+  "tf_device.launch"() ( {
+    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
+    tf_device.return
+  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
+  %1 = "tf_device.parallel_execute"() ( {
+    %2 = "tf_device.launch"() ( {
+      %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32>
+      tf_device.return %3 : tensor<8xi32>
+    }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
+    tf_device.return %2 : tensor<8xi32>
+  },  {
+    "tf_device.launch"() ( {
+      "tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> ()
+      tf_device.return
+    }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
+    tf_device.return
+  }) : () -> tensor<8xi32>
+  return %1 : tensor<8xi32>
+}
+```
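+
+With `num_cores_per_replica = 2`, `tf._TPUCompileMlir` returns one program per core, and
+`tf_device.parallel_execute` runs the per-core programs concurrently, each in its own
+`tf_device.launch` pinned to a distinct TPU device. The `input_sharding_configuration` and
+`output_sharding_configuration` attributes hold serialized `xla::OpSharding` protos
+describing how arguments and results are partitioned across the logical cores; the value
+`"\08\01\1A\01\01\22\01\00"` above decodes to a `MAXIMAL` sharding assigned to logical core 0.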