Update EraseClusterFuncs in the TPU rewrite pass so that it passes ASan.

PiperOrigin-RevId: 349434078
Change-Id: Ib26b7d12419c70e98c4c1a5c02b86bdd141658ad
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 4eba57e..c5a290d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -14,7 +14,6 @@
 ==============================================================================*/
 
 #include <cstdint>
-#include <deque>
 #include <string>
 #include <type_traits>
 
@@ -636,11 +635,6 @@
     tf_device::ClusterFuncOp cluster_func,
     llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
     OpBuilder* builder) {
-  // Skip non-tpu device cluster_func.
-  auto replicate_attr =
-      cluster_func->getAttrOfType<StringAttr>("_tpu_replicate");
-  if (!replicate_attr) return success();
-
   // Collect `num_replicas` and `num_cores_per_replica` attributes.
   int num_replicas = 1;
   tf_device::ReplicateOp replicate =
@@ -776,23 +770,29 @@
 // Erase rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
 // TPUPartitionedOutputOp are present, they must be removed alongwith the
 // ClusterFuncOp(s).
-void EraseClusterFuncs(std::deque<Operation*>& to_be_erased) {
+void EraseClusterFuncs(
+    llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
   for (auto cluster : to_be_erased) {
-    // Add TPUPartitionedInputOp inputs to ClusterFuncOp to be removed at the
-    // end.
-    for (auto operand : cluster->getOperands()) {
-      if (llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
-              operand.getDefiningOp()))
-        to_be_erased.emplace_back(operand.getDefiningOp());
+    for (auto result : cluster.results()) {
+      for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
+        if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
+          assert(user->use_empty());
+          user->erase();
+        }
+      }
     }
-    // Add TPUPartitionedOutputOp users of ClusterFuncOp to be removed first.
-    for (auto user : cluster->getUsers()) {
-      if (llvm::isa<TF::TPUPartitionedOutputOp>(user))
-        to_be_erased.emplace_front(user);
+
+    for (auto operand : cluster.operands()) {
+      Operation* def = operand.getDefiningOp();
+      if (operand.hasOneUse() &&
+          llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
+        operand.dropAllUses();
+        def->erase();
+      }
     }
-  }
-  for (auto op : to_be_erased) {
-    if (op->use_empty()) op->erase();
+
+    assert(cluster->use_empty());
+    cluster->erase();
   }
 }
 
@@ -801,13 +801,17 @@
   if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
     return signalPassFailure();
 
-  std::deque<Operation*> to_be_erased;
+  llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
   OpBuilder builder(&getContext());
   auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
+    // Skip non-TPU cluster_func ops (those without a `_tpu_replicate` attribute).
+    auto replicate_attr = op->getAttrOfType<StringAttr>("_tpu_replicate");
+    if (!replicate_attr) return WalkResult::advance();
+
     if (failed(Rewrite(op, devices.device_names(), &builder)))
       return WalkResult::interrupt();
 
-    to_be_erased.emplace_back(op.getOperation());
+    to_be_erased.push_back(op);
     return WalkResult::advance();
   });