/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include <memory>
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/metrics.h"
namespace mlir {
namespace {
// Adds a logger to the bridge pass manager and enables per-pass timing
// statistics.
void EnableDetailedLogging(PassManager *pm) {
// Print the whole module after each pass, which requires disabling
// multi-threading as well.
pm->getContext()->disableMultithreading();
pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
/*print_module_scope=*/true));
pm->enableTiming();
}
} // namespace
namespace TFTPU {
namespace {
// Runs the TF XLA Bridge with the given pipeline, which can be either the TPU
// bridge pipeline or the non-TPU bridge pipeline.
tensorflow::Status RunTFXLABridge(
ModuleOp module, bool enable_logging,
llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
PassManager bridge(module.getContext());
::tensorflow::applyTensorflowAndCLOptions(bridge);
if (enable_logging || VLOG_IS_ON(1)) {
tensorflow::DumpMlirOpToFile("tf_xla_bridge_before", module);
if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
}
// Populate a pass manager with the list of passes that implement the bridge.
pipeline_builder(bridge);
// Add set of passes to lower back to graph (from tf_executor).
TF::AddGraphExportLoweringPasses(bridge);
mlir::StatusScopedDiagnosticHandler diag_handler(
module.getContext(), /*propagate=*/false,
/*filter_stack=*/!VLOG_IS_ON(1));
LogicalResult result = bridge.run(module);
(void)result;
if (enable_logging || VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("tf_xla_bridge_after", module);
return diag_handler.ConsumeStatus();
}
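// Shared TPU bridge pipeline implementation, used by both the V2
// (CreateTPUBridgePipeline) and the V1 compat (CreateTPUBridgePipelineV1)
// entry points below.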
void CreateTPUBridgePipelineImpl(OpPassManager &pm) {
// The following ops must be preserved regardless of reachability. Ideally,
// all graphs should have control dependencies to enforce this but this is
// currently not the case (see b/177478741).
const llvm::SmallVector<std::string, 4> ops_to_preserve = {
"tf.TPUReplicateMetadata", "tf.TPUCompilationResult",
"tf.TPUReplicatedOutput"};
pm.addNestedPass<func::FuncOp>(
tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
// It is assumed that at this stage there are no V1 control flow ops as Graph
// functionalization is run before import. Ops can be lifted out of
// tf_executor dialect islands/graphs.
pm.addNestedPass<func::FuncOp>(
CreateExecutorDialectToFunctionalConversionPass());
// Guarantee all functions have one use, which enables more exact shape
// inference.
pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
// Run shape inference so that tf_executor/tf_device ops created later will
// be more likely to inherit more concrete types.
pm.addPass(TF::CreateTFShapeInferencePass());
pm.addNestedPass<func::FuncOp>(
CreateTPUReorderReplicateAndPartitionedInputsPass());
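// Form tf_device.cluster ops from ops that carry matching TPU
// replication/compilation attributes.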
pm.addPass(CreateTPUClusterFormationPass());
// Run TPU cluster attribute cleanup so that ops without an outside
// compilation attribute do not carry a host device attribute.
pm.addPass(CreateTPUClusterCleanupAttributesPass());
pm.addPass(CreateOutsideCompiledToHostLaunchPass());
pm.addNestedPass<func::FuncOp>(TFDevice::CreateDeviceAttributeToLaunchPass());
// Running canonicalizer before decomposing resource ops in cluster helps the
// latter pass to converge faster as it does not have to spend time folding
// away dead ops.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Place DecomposeResourceOpsPass before the TFExecutorConstantSinking pass
// because DecomposeResourceOpsPass uses a pattern rewriter which hoists
// changed constants out of tf_device.Launch.
pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
// Encode this in its own scope so that func_pm is not mistakenly used
// later on.
{
OpPassManager &func_pm = pm.nest<func::FuncOp>();
func_pm.addPass(CreateTPUHostComputationExpansionPass());
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
}
// TODO(b/173622615): This should incrementally be moved down as
// more passes support this representation and then can be removed once
// all passes support it.
pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());
// TODO(b/173622615): Once OutsideCompilation is represented by launch op and
// the remaining passes including Inliner support it, remove this
// LaunchToDeviceAttributePass. This LaunchToDeviceAttribute pass needs to
// come before TPUClusterCleanupAttributes pass or else the device attribute
// will be removed from launch causing an error.
pm.addNestedPass<func::FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());
// TODO(b/173622615): This can be removed once more passes support outside
// compilation represented by op and conversion back to attribute is removed.
pm.addPass(CreateOutsideCompiledToHostLaunchPass());
// Note that the region-based control-flow produced here still contains
// function call ops which get inlined by the subsequent inliner pass.
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addNestedPass<func::FuncOp>(
TF::CreateDropWhileShapeInvariantInDeviceClusterPass());
// Run another shape inference pass because resource decomposition might have
// created new partial types. Also, dropping the `shape_invariant` attribute
// from While/WhileRegion ops within the cluster may lead to more precise
// shapes.
pm.addPass(TF::CreateTFShapeInferencePass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addPass(CreateTPUClusterCleanupAttributesPass());
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
// Re-run the canonicalizer pass as some cleanup during the resource op
// lifting pass opens up opportunities for canonicalization of cluster ops.
// Specifically, we want to eliminate pass-through results from the cluster
// op.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO(b/173622615): This should incrementally be moved down as
// more passes support this representation and then can be removed once
// all passes support it.
pm.addPass(TFDevice::CreateHostLaunchToOutsideCompiledPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
if (tensorflow::GetMlirCommonFlags()
->tf_mlir_enable_merge_control_flow_pass) {
pm.addPass(TFDevice::CreateMergeControlFlowPass());
}
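// Mark ops that cannot be compiled for TPU as outside compiled, then extract
// head/tail and remaining outside compiled ops from the cluster so they run
// on the host.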
pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
pm.addPass(CreateTPUExtractOutsideCompilationPass());
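// Sink constants into clusters, infer devices for resource variables, and
// outline each cluster into its own function.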
pm.addNestedPass<func::FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
pm.addPass(TF::CreateResourceDeviceInferencePass());
pm.addPass(TFDevice::CreateClusterOutliningPass());
pm.addPass(CreateTPUResourceReadForWritePass());
pm.addPass(TFDevice::CreateMarkInputOutputAliasesPass());
pm.addPass(CreateTPUShardingIdentificationPass());
pm.addNestedPass<func::FuncOp>(
CreateTPUResourceReadsWritesPartitioningPass());
pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
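// Lower each outlined cluster function into TPU compilation and execution
// ops.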
pm.addPass(CreateTPURewritePass());
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<func::FuncOp>(
TFDevice::CreateReplicateInvariantOpHoistingPass());
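// Merge resource variable reads/writes directly into the TPU execute op
// where possible.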
pm.addPass(CreateTPUMergeVariablesWithExecutePass());
pm.addNestedPass<func::FuncOp>(
TF::CreateHoistReplicateInvariantResourceWritesPass());
pm.addNestedPass<func::FuncOp>(CreateTPUColocateCompositeResourceOps());
pm.addPass(CreateTPUVariableRuntimeReformattingPass());
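// Convert region-based control flow back to functional form before export.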
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
} // namespace
void CreateTPUBridgePipeline(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(
CreateCanonicalizeCompileAndReplicateAttributesPass());
CreateTPUBridgePipelineImpl(pm);
}
void CreateTPUBridgePipelineV1(OpPassManager &pm) {
// Convert to unified compilation and replication attributes.
pm.addNestedPass<func::FuncOp>(
CreateCanonicalizeCompileAndReplicateAttributesPass());
// Guarantee all functions have one use, which enables more exact shape
// inference.
pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
pm.addPass(TF::CreateTFShapeInferencePass());
// For V1 compatibility, we process a module where the graph does not have
// feeds and fetches. We first extract the TPU computation into a submodule,
// where it'll be in a function with args and returned values, much more like
// a TF v2 module. We can then run the usual pipeline on this nested module.
// Afterward we inline it back into the parent module and delete the nested
// one.
pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
OpPassManager &nested_module = pm.nest<ModuleOp>();
CreateTPUBridgePipelineImpl(nested_module);
pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
// There are cases where we don't consume all compilation and replication
// attributes like we do for the V2 pipeline, so we need to convert them from
// unified to legacy attributes before they are exposed outside of the bridge.
pm.addNestedPass<func::FuncOp>(
CreateConvertToLegacyCompileAndReplicateAttributesPass());
}
tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging,
bool fallback_enabled) {
Status status =
RunTFXLABridge(module, enable_logging, CreateTPUBridgePipeline);
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
"tpu", "v2", fallback_enabled,
status == Status::OK() ? "success" : "failure");
return status;
}
tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging,
bool fallback_enabled) {
Status status =
RunTFXLABridge(module, enable_logging, CreateTPUBridgePipelineV1);
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
"tpu", "v1", fallback_enabled,
status == Status::OK() ? "success" : "failure");
return status;
}
} // namespace TFTPU
namespace TF {
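// Adds passes that lower the module back to tf_executor islands, break up
// multi-op islands after each lowering step, and verify the result is
// suitable for export back to a Graph.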
void AddGraphExportLoweringPasses(OpPassManager &pm) {
auto add_pass = [&](std::unique_ptr<Pass> pass) {
pm.addNestedPass<func::FuncOp>(std::move(pass));
pm.addPass(CreateBreakUpIslandsPass());
};
add_pass(CreateFunctionalToExecutorDialectConversionPass());
add_pass(TFDevice::CreateReplicateToIslandPass());
add_pass(TFDevice::CreateReplicaIDToDeviceOrdinalPass());
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
pm.addNestedPass<func::FuncOp>(TFTPU::CreateTPUDevicePropagationPass());
pm.addPass(createSymbolDCEPass());
if (tensorflow::GetMlirCommonFlags()
->tf_mlir_enable_convert_control_to_data_outputs_pass) {
pm.addPass(tf_executor::CreateTFExecutorConvertControlToDataOutputsPass());
}
pm.addPass(CreateVerifySuitableForExportPass());
}
tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
bool enable_logging,
bool enable_inliner) {
PassManager bridge(module.getContext());
if (enable_logging || VLOG_IS_ON(1)) {
tensorflow::DumpMlirOpToFile("standard_pipeline_before", module);
if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge);
}
StandardPipelineOptions pipeline_options;
pipeline_options.enable_inliner.setValue(enable_inliner);
CreateTFStandardPipeline(bridge, pipeline_options);
mlir::StatusScopedDiagnosticHandler diag_handler(
module.getContext(), /*propagate=*/false,
/*filter_stack=*/!VLOG_IS_ON(1));
LogicalResult result = bridge.run(module);
(void)result;
if (enable_logging || VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("standard_pipeline_after", module);
return diag_handler.ConsumeStatus();
}
namespace {
void CreateTFXLABridgePipeline(OpPassManager &pm) {
VLOG(2) << "Create TF XLA Bridge pipeline";
// The following ops must be preserved regardless of reachability. Ideally,
// all graphs should have control dependencies to enforce this; currently no
// ops need to be preserved for this pipeline.
const llvm::SmallVector<std::string, 4> ops_to_preserve = {};
pm.addNestedPass<func::FuncOp>(
tf_executor::CreateTFExecutorGraphPruningPass(ops_to_preserve));
// It is assumed that at this stage there are no V1 control flow ops as Graph
// functionalization is run before import. Ops can be lifted out of
// tf_executor dialect islands/graphs.
pm.addNestedPass<func::FuncOp>(
CreateExecutorDialectToFunctionalConversionPass());
// Guarantee all functions have one use, which enables more exact shape
// inference.
pm.addPass(TF::CreateTFShapeInferencePass());
// Running canonicalizer before decomposing resource ops in cluster helps the
// latter pass to converge faster as it does not have to spend time folding
// away dead ops.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Encapsulate StatefulPartitionedCallOp within a cluster so that the
// composite resource ops can be decomposed.
pm.addPass(TFDevice::CreateXlaClusterFormationPass());
// Decompose composite resource ops within the cluster.
pm.addPass(TFDevice::CreateDecomposeResourceOpsInClusterPass());
// Run another shape inference pass because resource decomposition might have
// created new partial types. Also, dropping the `shape_invariant` attribute
// from While/WhileRegion ops within the cluster may lead to more precise
// shapes.
pm.addPass(TF::CreateTFShapeInferencePass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
// Inline the cluster formed around the StatefulPartitionedCallOp into its
// parent region.
pm.addPass(TFDevice::CreateXlaInlineDeviceOpsPass());
// Re-run the canonicalizer pass as some cleanup during the resource op
// lifting pass opens up opportunities for canonicalization of cluster ops.
// Specifically, we want to eliminate pass-through results from the cluster
// op.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addPass(createSymbolDCEPass());
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
}
} // namespace
tensorflow::Status RunTFXLABridge(ModuleOp module, bool enable_logging) {
Status status = mlir::TFTPU::RunTFXLABridge(module, enable_logging,
CreateTFXLABridgePipeline);
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
/*device_type=*/"cpu/gpu", /*bridge_version=*/"tfxla",
/*fallback_enabled=*/false,
/*result=*/status == Status::OK() ? "success" : "failure");
return status;
}
} // namespace TF
} // namespace mlir