Add a partial-conversion TF->XLA lowering pass, a graph pruning pass, and a canonicalization pass to the TF->XLA pipeline
PiperOrigin-RevId: 285548949
Change-Id: Ie6ca4112043a4d314b471dd0adf9172b6f612de0
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 8f761ac..4e914a5 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -26,6 +26,7 @@
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassManager.h" // TF:local_config_mlir
#include "mlir/Transforms/Passes.h" // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
@@ -210,7 +211,16 @@
mlir::PassManager tf2xla(module_op.getContext());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
- tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass());
+ // LegalizeTFPass is run twice: the first invocation, with
+ // allow_partial_conversion=true, can expose additional graph pruning and
+ // canonicalization opportunities that are needed for the second, strict
+ // invocation (allow_partial_conversion=false) to fully legalize the graph.
+ tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
+ tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass(
+ /*skip_main_func=*/true));
+ tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+ tf2xla.addNestedPass<mlir::FuncOp>(
+ mlir::xla_hlo::createLegalizeTFPass(false));
{
// Make sure we catch any error reported by MLIR and forward it to the TF
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index 1668cf6..b007687 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -141,15 +141,14 @@
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
ASSERT_TRUE(status_or_hlo_module.ok());
- string expected_hlo_module_string = R"(HloModule main.7
+ string expected_hlo_module_string = R"(HloModule main.6
-ENTRY %main.7 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
+ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
%arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0)
%get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0
- %constant.4 = s64[2]{0} constant({10, 19})
%get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1
- %reshape.5 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3)
- ROOT %tuple.6 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.5)
+ %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3)
+ ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4)
}
)";