Update EraseClusterFuncs in the TPU rewrite pass so that it passes under ASan.
PiperOrigin-RevId: 349434078
Change-Id: Ib26b7d12419c70e98c4c1a5c02b86bdd141658ad
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 4eba57e..c5a290d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -14,7 +14,6 @@
==============================================================================*/
#include <cstdint>
-#include <deque>
#include <string>
#include <type_traits>
@@ -636,11 +635,6 @@
tf_device::ClusterFuncOp cluster_func,
llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
OpBuilder* builder) {
- // Skip non-tpu device cluster_func.
- auto replicate_attr =
- cluster_func->getAttrOfType<StringAttr>("_tpu_replicate");
- if (!replicate_attr) return success();
-
// Collect `num_replicas` and `num_cores_per_replica` attributes.
int num_replicas = 1;
tf_device::ReplicateOp replicate =
@@ -776,23 +770,29 @@
// Erase rewritten ClusterFuncOp(s). If TPUPartitionedInputOp /
// TPUPartitionedOutputOp are present, they must be removed alongwith the
// ClusterFuncOp(s).
-void EraseClusterFuncs(std::deque<Operation*>& to_be_erased) {
+void EraseClusterFuncs(
+ llvm::MutableArrayRef<tf_device::ClusterFuncOp> to_be_erased) {
for (auto cluster : to_be_erased) {
- // Add TPUPartitionedInputOp inputs to ClusterFuncOp to be removed at the
- // end.
- for (auto operand : cluster->getOperands()) {
- if (llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
- operand.getDefiningOp()))
- to_be_erased.emplace_back(operand.getDefiningOp());
+ for (auto result : cluster.results()) {
+ for (Operation* user : llvm::make_early_inc_range(result.getUsers())) {
+ if (llvm::isa<TF::TPUPartitionedOutputOp>(user)) {
+ assert(user->use_empty());
+ user->erase();
+ }
+ }
}
- // Add TPUPartitionedOutputOp users of ClusterFuncOp to be removed first.
- for (auto user : cluster->getUsers()) {
- if (llvm::isa<TF::TPUPartitionedOutputOp>(user))
- to_be_erased.emplace_front(user);
+
+ for (auto operand : cluster.operands()) {
+ Operation* def = operand.getDefiningOp();
+ if (operand.hasOneUse() &&
+ llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(def)) {
+ operand.dropAllUses();
+ def->erase();
+ }
}
- }
- for (auto op : to_be_erased) {
- if (op->use_empty()) op->erase();
+
+ assert(cluster->use_empty());
+ cluster->erase();
}
}
@@ -801,13 +801,17 @@
if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
return signalPassFailure();
- std::deque<Operation*> to_be_erased;
+ llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
OpBuilder builder(&getContext());
auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
+ // Skip non-tpu device cluster_func.
+ auto replicate_attr = op->getAttrOfType<StringAttr>("_tpu_replicate");
+ if (!replicate_attr) return WalkResult::advance();
+
if (failed(Rewrite(op, devices.device_names(), &builder)))
return WalkResult::interrupt();
- to_be_erased.emplace_back(op.getOperation());
+ to_be_erased.push_back(op);
return WalkResult::advance();
});