Replace logging in MLIR bridge passes with PassInstrumentation that handles logging via PassManager.

PiperOrigin-RevId: 280675323
Change-Id: I5d3406bd7b55b126b2e939b9589a041998ba5dcb
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 4706c3e..978ec60 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -230,10 +230,10 @@
     ],
     includes = ["include"],
     deps = [
+        ":bridge_logger",
         ":convert_tensor",
         ":convert_type",
         ":device_util",
-        ":dump_mlir_util",
         ":error_util",
         ":export_tf_dialect_op",
         ":mangling_util",
@@ -949,3 +949,15 @@
         "@local_config_mlir//:IR",
     ],
 )
+
+cc_library(
+    name = "bridge_logger",
+    srcs = ["utils/bridge_logger.cc"],
+    hdrs = ["utils/bridge_logger.h"],
+    deps = [
+        ":dump_mlir_util",
+        "@llvm//:support",
+        "@local_config_mlir//:IR",
+        "@local_config_mlir//:Pass",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index 695c1bf..d955eed 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -15,9 +15,12 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
 
+#include <memory>
+
 #include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
 #include "mlir/Transforms/Passes.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 
 namespace mlir {
@@ -41,9 +44,14 @@
   pm.addNestedPass<FuncOp>(createCanonicalizerPass());
 }
 
-tensorflow::Status TPUBridge(ModuleOp module) {
-  // Populate a passmanager with the list of passes that implement the bridge.
+tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
   PassManager bridge(module.getContext());
+
+  // Add logging instrumentation to the bridge pass manager.
+  if (enable_logging)
+    bridge.addInstrumentation(std::make_unique<tensorflow::BridgeLogger>());
+
+  // Populate the pass manager with the list of passes that implement the bridge.
   createTPUBridge(bridge);
 
   // Run the bridge on the module, in case of failure, the `diag_handler`
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
index 2a3d380..6b55b0c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h
@@ -24,7 +24,7 @@
 
 // Run all the passes involved in transforming the graph before execution so
 // that it is suitable for targeting TPUs.
-tensorflow::Status TPUBridge(ModuleOp module);
+tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging);
 
 }  // namespace TFTPU
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
index f4a2b62..7dab061 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
@@ -29,8 +29,6 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace mlir {
 namespace TFDevice {
@@ -125,20 +123,12 @@
 }
 
 void ClusterOutliningPass::runOnModule() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_device_cluster_outlining_before",
-                                 getModule());
-
   ModuleOp m = getModule();
   ModuleManager module_manager(m);
   OpBuilder builder(m.getContext());
   m.walk([&](tf_device::LaunchOp launch) {
     OutlineLaunch(launch, &module_manager, &builder);
   });
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_device_cluster_outlining_after",
-                                 getModule());
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
index b2cae78..c6958d9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
@@ -35,7 +35,6 @@
 #include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace mlir {
@@ -321,10 +320,6 @@
 }
 
 void ExecutorIslandCoarsening::runOnFunction() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_executor_island_coarsening_before",
-                                 getFunction());
-
   getFunction().walk([](GraphOp graph) {
     InsertDummyIslandForFetch(graph.GetFetch());
 
@@ -348,10 +343,6 @@
       }
     } while (updated);
   });
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_executor_island_coarsening_after",
-                                 getFunction());
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
index 87d0c90..36f6f3a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
@@ -27,8 +27,6 @@
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace mlir {
 namespace TFDevice {
@@ -148,16 +146,8 @@
 }
 
 void ReplicateInvariantOpHoistingPass::runOnFunction() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_replicate_invariant_op_hoisting_before",
-                                 getFunction());
-
   getFunction().walk(
       [](tf_device::ReplicateOp op) { HoistReplicateInvariantOps(op); });
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_replicate_invariant_op_hoisting_after",
-                                 getFunction());
 }
 }  // anonymous namespace
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
index 6a4bba7..8033773 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
@@ -33,8 +33,6 @@
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace mlir {
 namespace TFDevice {
@@ -198,10 +196,6 @@
 }
 
 void ReplicateToIslandPass::runOnFunction() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile(
-        "mlir_device_replicate_to_executor_island_before", getFunction());
-
   const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
   if (!tf_dialect) {
     signalPassFailure();
@@ -211,10 +205,6 @@
   getFunction().walk([&](tf_executor::IslandOp island_op) {
     LowerSingleIslandReplicateToIslands(tf_dialect, island_op);
   });
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile(
-        "mlir_device_replicate_to_executor_island_after", getFunction());
 }
 }  // anonymous namespace
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 7e59a8c..507a96e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -30,9 +30,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace mlir {
 namespace TFDevice {
@@ -377,17 +375,7 @@
   });
 }
 
-void ResourceOpLiftingPass::runOnFunction() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_resource_op_lifting_before",
-                                 getFunction());
-
-  LiftResourceOps(getFunction());
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_resource_op_lifting_after",
-                                 getFunction());
-}
+void ResourceOpLiftingPass::runOnFunction() { LiftResourceOps(getFunction()); }
 
 std::unique_ptr<OpPassBase<FuncOp>> CreateResourceOpLiftingPass() {
   return std::make_unique<ResourceOpLiftingPass>();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
index 6e6b2d4..e4358e7 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
@@ -28,9 +28,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 #define DEBUG_TYPE "tf-executor-sink-constant"
 
@@ -43,10 +41,6 @@
 class ExecutorConstantSinking
     : public mlir::FunctionPass<ExecutorConstantSinking> {
   void runOnFunction() override {
-    if (VLOG_IS_ON(1))
-      tensorflow::DumpMlirOpToFile("mlir_device_constant_sinking_before",
-                                   getFunction());
-
     getFunction().walk([](tf_device::LaunchOp launch) {
       LLVM_DEBUG(llvm::dbgs() << "Visit " << *launch.getOperation() << "\n");
       // For each launch op, we find the values used that come from a constant
@@ -86,10 +80,6 @@
                                 << "\n     in " << *use->get() << "\n");
       });
     });
-
-    if (VLOG_IS_ON(1))
-      tensorflow::DumpMlirOpToFile("mlir_device_constant_sinking_after",
-                                   getFunction());
   }
 };
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 0e52b07..6580ad5 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -47,8 +47,6 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 namespace mlir {
 namespace TFTPU {
@@ -410,10 +408,6 @@
 }
 
 void TPUClusterFormation::runOnFunction() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_tpu_cluster_formation_before",
-                                 getFunction());
-
   MetadataMap metadata_map;
   if (failed(CollectMetadata(getFunction(), &metadata_map)))
     return signalPassFailure();
@@ -456,10 +450,6 @@
   });
 
   if (remove_result.wasInterrupted()) return signalPassFailure();
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_tpu_cluster_formation_after",
-                                 getFunction());
 }
 }  // anonymous namespace
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index bf3c1af..9f71bb7 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -41,7 +41,6 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,7 +48,6 @@
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -508,9 +506,6 @@
 }
 
 void TPURewritePass::runOnModule() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_tpu_rewrite_before", getModule());
-
   llvm::SmallVector<tensorflow::DeviceNameUtils::ParsedName, 8> devices;
   if (failed(tensorflow::GetDevicesFromOp(getModule(), &devices)))
     return signalPassFailure();
@@ -528,9 +523,6 @@
   getModule().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
 
   // TODO(b/139377366): Remove functions that are no longer needed.
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_tpu_rewrite_after", getModule());
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
index 16b0d6f..22d04b2 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
@@ -22,8 +22,6 @@
 #include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
 #include "mlir/Support/STLExtras.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 // This pass is used in preparation for Graph export.
 // The GraphDef exporter expects each op to be in its own island.
@@ -47,10 +45,6 @@
 }  // end anonymous namespace
 
 void BreakUpIslands::runOnOperation() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_executor_breakup_islands_before",
-                                 getOperation());
-
   auto graph_op_range = getOperation().getBody().front().without_terminator();
   tf_executor::GraphOp graph_op;
   if (graph_op_range.begin() != graph_op_range.end() &&
@@ -110,10 +104,6 @@
     new_op->setAttrs(item.getAttrList());
     item.erase();
   }
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_executor_breakup_islands_after",
-                                 getOperation());
 }
 
 // Converts a single island into multiple islands (one for each op). The islands
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
index 6c5cdce..ff397e4 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
@@ -20,8 +20,6 @@
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
-#include "tensorflow/core/platform/logging.h"
 
 #define DEBUG_TYPE "tf-functional-to-executor"
 
@@ -48,10 +46,6 @@
 }  // end anonymous namespace
 
 void FunctionalToExecutorDialectConversion::runOnFunction() {
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_functional_to_executor_before",
-                                 getFunction());
-
   if (getFunction().getBlocks().size() != 1) {
     LLVM_DEBUG(llvm::dbgs() << "Expect single block function, skip conversion "
                                "to tf_executor dialect\n");
@@ -101,10 +95,6 @@
   for (auto item : llvm::enumerate(graph_op.getResults())) {
     return_op.setOperand(item.index(), item.value());
   }
-
-  if (VLOG_IS_ON(1))
-    tensorflow::DumpMlirOpToFile("mlir_functional_to_executor_after",
-                                 getFunction());
 }
 
 std::unique_ptr<OpPassBase<FuncOp>>
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc
new file mode 100644
index 0000000..a37e092
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc
@@ -0,0 +1,42 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
+
+namespace tensorflow {
+
+// Logs op to a file with a name of the format `mlir_bridge-pass_name-file_suffix.mlir`.
+inline static void Log(mlir::Pass* pass, mlir::Operation* op,
+                       llvm::StringRef file_suffix) {
+  DumpMlirOpToFile(
+      llvm::formatv("mlir_bridge-{0}-{1}", pass->getName(), file_suffix).str(),
+      op);
+}
+
+void BridgeLogger::runBeforePass(mlir::Pass* pass, mlir::Operation* op) {
+  Log(pass, op, "before");
+}
+
+void BridgeLogger::runAfterPass(mlir::Pass* pass, mlir::Operation* op) {
+  Log(pass, op, "after");
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h
new file mode 100644
index 0000000..2943a378
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h
@@ -0,0 +1,35 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
+
+#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Pass/PassInstrumentation.h"  // TF:local_config_mlir
+
+namespace tensorflow {
+
+// Logger for logging/dumping MLIR modules before and after passes in bridge
+// targeting TPUs.
+class BridgeLogger : public mlir::PassInstrumentation {
+ public:
+  void runBeforePass(mlir::Pass* pass, mlir::Operation* op) override;
+  void runAfterPass(mlir::Pass* pass, mlir::Operation* op) override;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
index deb1b84..ff7fb2b 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
@@ -106,7 +106,8 @@
   if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_before_");
 
   // Run the bridge now
-  TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridge(*module));
+  TF_RETURN_IF_ERROR(
+      mlir::TFTPU::TPUBridge(*module, /*enable_logging=*/VLOG_IS_ON(1)));
 
   if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_after_");