Introduce functional<->region conversion passes around extract outside compilation
- Follow functional->region transformation with a inlining pass to make sure calls
generated by the transform get inlined.
PiperOrigin-RevId: 324282111
Change-Id: Ifaacec3d8919f390fdeda8ca9af129d8e7dce086
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index cb1dd23..ed0528a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -82,15 +82,23 @@
// Run shape inference so that tf_executor/tf_device ops created later will
// likely to inherit more concrete types.
pm.addPass(TF::CreateTFShapeInferencePass());
- OpPassManager &func_pm = pm.nest<FuncOp>();
- func_pm.addPass(CreateTPUClusterFormationPass());
- // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
- // because DecomposeResourceOpsPass uses pattern rewriter which hoists
- // changed constants out of tf_device.Launch.
- func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
- func_pm.addPass(CreateTPUHostComputationExpansionPass());
- pm.addNestedPass<FuncOp>(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
+ // Encode this in its own scope so that func_pm is not mistakenly used
+ // later on.
+ {
+ OpPassManager &func_pm = pm.nest<FuncOp>();
+ func_pm.addPass(CreateTPUClusterFormationPass());
+ // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass
+ // because DecomposeResourceOpsPass uses pattern rewriter which hoists
+ // changed constants out of tf_device.Launch.
+ func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
+ func_pm.addPass(CreateTPUHostComputationExpansionPass());
+ func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
+ }
+ pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
+ pm.addPass(mlir::createInlinerPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
+ pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
+
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());