Add a partial-conversion TF->XLA lowering, a graph pruning pass, and a canonicalization pass to the TF->XLA pipeline

PiperOrigin-RevId: 285548949
Change-Id: Ie6ca4112043a4d314b471dd0adf9172b6f612de0
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 8f761ac..4e914a5 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -26,6 +26,7 @@
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
 #include "mlir/Pass/PassManager.h"  // TF:local_config_mlir
 #include "mlir/Transforms/Passes.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
@@ -210,7 +211,16 @@
   mlir::PassManager tf2xla(module_op.getContext());
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
-  tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass());
+  // LegalizeTFPass is run twice: the first invocation (with
+  // allow_partial_conversion=true) can expose graph pruning and
+  // canonicalization opportunities that are needed for the second
+  // invocation (with allow_partial_conversion=false) to fully legalize.
+  tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
+  tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass(
+      /*skip_main_func=*/true));
+  tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  tf2xla.addNestedPass<mlir::FuncOp>(
+      mlir::xla_hlo::createLegalizeTFPass(false));
 
   {
     // Make sure we catch any error reported by MLIR and forward it to the TF
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index 1668cf6..b007687 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -141,15 +141,14 @@
   auto status_or_hlo_module = xla::HloModule::CreateFromProto(
       compilation_result.computation->proto(), module_config);
   ASSERT_TRUE(status_or_hlo_module.ok());
-  string expected_hlo_module_string = R"(HloModule main.7
+  string expected_hlo_module_string = R"(HloModule main.6
 
-ENTRY %main.7 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
+ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
   %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0)
   %get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0
-  %constant.4 = s64[2]{0} constant({10, 19})
   %get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1
-  %reshape.5 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3)
-  ROOT %tuple.6 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.5)
+  %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3)
+  ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4)
 }
 
 )";