[MLIR] Add a pass to cleanup attributes on operations within a cluster

- This pass will cleanup attributes like _tpu_replicate and device on operations that are
  within a cluster (similar to what cluster formation does)
- Intended to be used in the MLIR->XLA bridge post inlining following functional->region
  control flow.

PiperOrigin-RevId: 330735644
Change-Id: I1c67995c13aa2c5d3a110d2ea595e494b22a7558
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 98f4fd2..6441319 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -787,6 +787,7 @@
         "transforms/test_visitor_util.cc",
         "transforms/tf_data_optimization_pass.cc",
         "transforms/tf_device_assignment.cc",
+        "transforms/tpu_cluster_cleanup_attributes.cc",
         "transforms/tpu_cluster_formation.cc",
         "transforms/tpu_colocate_composite_resource_ops.cc",
         "transforms/tpu_dynamic_layout_pass.cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir
new file mode 100644
index 0000000..6399d7d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir
@@ -0,0 +1,24 @@
+// RUN: tf-opt %s -tf-tpu-cleanup-cluster-attributes | FileCheck %s
+
+func @test(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>) ->  tensor<f32> {
+  // CHECK: "tf_device.cluster"
+  // CHECK-NOT: _tpu_replicate =
+  // CHECK-NOT: device =
+  %1 = "tf_device.cluster"() ( {
+    %2 = "tf.Add"(%arg1, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %3 = "tf.IfRegion"(%arg0) ({
+        %4 = "tf.Mul" (%arg1, %2) {device = "y"}: (tensor<f32>, tensor<f32>) -> tensor<f32>
+        "tf.Yield"(%4) : (tensor<f32>) -> ()
+      }, {
+        %5 = "tf.Div" (%arg1, %2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+        "tf.Yield"(%5) : (tensor<f32>) -> ()
+      }) {is_stateless = true, _tpu_replicate = "x" } : (tensor<i1>) -> (tensor<f32>)
+    tf_device.return %3 : tensor<f32>
+  // CHECK: {_tpu_replicate = "x", cluster_attr = "cluster_attr", device = "y"}
+  }) {cluster_attr = "cluster_attr", _tpu_replicate = "x", device = "y"} : () -> tensor<f32>
+  // CHECK: "tf.Add"
+  // CHECK-SAME: {_tpu_replicate = "x", device = "y"}
+  %2 = "tf.Add"(%arg2, %1) {_tpu_replicate = "x", device = "y"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return
+  return %2 : tensor<f32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index dbb48dc..7dad109 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -276,6 +276,11 @@
 // `_tpu_replicate` attribute.
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();
 
+// Creates a pass that cleans up `_tpu_replicate` attribute on operations
+// that are inside a cluster.
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTPUClusterCleanupAttributesPass();
+
 // Creates a pass that removes Identity/IdentityN ops from a cluster.
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc
new file mode 100644
index 0000000..93098ac
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc
@@ -0,0 +1,60 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+
+// This pass eliminate `_tpu_replicate` and `device` attribute on operations
+// that are contained in a tf_device.cluster op.
+
+namespace mlir {
+namespace TFTPU {
+
+namespace {
+
+constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
+constexpr char kDeviceAttr[] = "device";
+
+class TPUCleanupClusterAttributesPass
+    : public PassWrapper<TPUCleanupClusterAttributesPass,
+                         OperationPass<ModuleOp>> {
+ public:
+  void runOnOperation() override {
+    getOperation().walk([](tf_device::ClusterOp cluster) {
+      cluster.walk([](Operation *op) {
+        if (isa<tf_device::ClusterOp>(op)) return;
+        for (StringRef attr : {kTPUReplicateAttr, kDeviceAttr})
+          op->removeAttr(attr);
+      });
+    });
+  }
+};
+
+PassRegistration<TPUCleanupClusterAttributesPass> pass(
+    "tf-tpu-cleanup-cluster-attributes",
+    "Eliminate _tpu_replicate and other attributes from ops in a cluster");
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTPUClusterCleanupAttributesPass() {
+  return std::make_unique<TPUCleanupClusterAttributesPass>();
+}
+
+}  // namespace TFTPU
+}  // namespace mlir