[MLIR] Add a pass to cleanup attributes on operations within a cluster
- This pass will cleanup attributes like _tpu_replicate and device on operations that are
within a cluster (similar to what cluster formation does)
- Intended to be used in the MLIR->XLA bridge post inlining following functional->region
control flow.
PiperOrigin-RevId: 330735644
Change-Id: I1c67995c13aa2c5d3a110d2ea595e494b22a7558
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 98f4fd2..6441319 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -787,6 +787,7 @@
"transforms/test_visitor_util.cc",
"transforms/tf_data_optimization_pass.cc",
"transforms/tf_device_assignment.cc",
+ "transforms/tpu_cluster_cleanup_attributes.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_colocate_composite_resource_ops.cc",
"transforms/tpu_dynamic_layout_pass.cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir
new file mode 100644
index 0000000..6399d7d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-cluster-cleanup-attributes.mlir
@@ -0,0 +1,24 @@
+// RUN: tf-opt %s -tf-tpu-cleanup-cluster-attributes | FileCheck %s
+
+func @test(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+ // CHECK: "tf_device.cluster"
+ // CHECK-NOT: _tpu_replicate =
+ // CHECK-NOT: device =
+ %1 = "tf_device.cluster"() ( {
+ %2 = "tf.Add"(%arg1, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %3 = "tf.IfRegion"(%arg0) ({
+ %4 = "tf.Mul" (%arg1, %2) {device = "y"}: (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "tf.Yield"(%4) : (tensor<f32>) -> ()
+ }, {
+ %5 = "tf.Div" (%arg1, %2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "tf.Yield"(%5) : (tensor<f32>) -> ()
+ }) {is_stateless = true, _tpu_replicate = "x" } : (tensor<i1>) -> (tensor<f32>)
+ tf_device.return %3 : tensor<f32>
+ // CHECK: {_tpu_replicate = "x", cluster_attr = "cluster_attr", device = "y"}
+ }) {cluster_attr = "cluster_attr", _tpu_replicate = "x", device = "y"} : () -> tensor<f32>
+ // CHECK: "tf.Add"
+ // CHECK-SAME: {_tpu_replicate = "x", device = "y"}
+ %2 = "tf.Add"(%arg2, %1) {_tpu_replicate = "x", device = "y"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ // CHECK: return
+ return %2 : tensor<f32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index dbb48dc..7dad109 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -276,6 +276,11 @@
// `_tpu_replicate` attribute.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();
+// Creates a pass that cleans up `_tpu_replicate` attribute on operations
+// that are inside a cluster.
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTPUClusterCleanupAttributesPass();
+
// Creates a pass that removes Identity/IdentityN ops from a cluster.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc
new file mode 100644
index 0000000..93098ac
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_cleanup_attributes.cc
@@ -0,0 +1,60 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/Pass/PassManager.h" // from @llvm-project
+#include "mlir/Transforms/Passes.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+
+// This pass eliminate `_tpu_replicate` and `device` attribute on operations
+// that are contained in a tf_device.cluster op.
+
+namespace mlir {
+namespace TFTPU {
+
+namespace {
+
+constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
+constexpr char kDeviceAttr[] = "device";
+
+class TPUCleanupClusterAttributesPass
+ : public PassWrapper<TPUCleanupClusterAttributesPass,
+ OperationPass<ModuleOp>> {
+ public:
+ void runOnOperation() override {
+ getOperation().walk([](tf_device::ClusterOp cluster) {
+ cluster.walk([](Operation *op) {
+ if (isa<tf_device::ClusterOp>(op)) return;
+ for (StringRef attr : {kTPUReplicateAttr, kDeviceAttr})
+ op->removeAttr(attr);
+ });
+ });
+ }
+};
+
+PassRegistration<TPUCleanupClusterAttributesPass> pass(
+ "tf-tpu-cleanup-cluster-attributes",
+ "Eliminate _tpu_replicate and other attributes from ops in a cluster");
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateTPUClusterCleanupAttributesPass() {
+ return std::make_unique<TPUCleanupClusterAttributesPass>();
+}
+
+} // namespace TFTPU
+} // namespace mlir