| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" |
| |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Function.h" // from @llvm-project |
| #include "mlir/IR/Module.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Pass/PassManager.h" // from @llvm-project |
| #include "mlir/Transforms/Passes.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" |
| #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" |
| #include "tensorflow/compiler/mlir/lite/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" |
| |
| namespace mlir { |
| /// Create a pass to convert from the TFExecutor to the TF control dialect. |
| std::unique_ptr<OperationPass<FuncOp>> |
| CreateTFExecutorToControlDialectConversion(); |
| } // namespace mlir |
| |
| namespace tensorflow { |
| |
| void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, |
| mlir::OpPassManager* pass_manager) { |
| pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs)); |
| if (quant_specs.default_ranges.first.hasValue() || |
| quant_specs.default_ranges.second.hasValue()) { |
| pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( |
| quant_specs.default_ranges.first.getValueOr(0.0), |
| quant_specs.default_ranges.second.getValueOr(0.0), |
| quant_specs.IsSignedInferenceType())); |
| } |
| pass_manager->addPass(mlir::TFL::CreateQuantizePass()); |
| bool emit_quant_adaptor_ops = |
| quant_specs.inference_type != quant_specs.inference_input_type; |
| pass_manager->addPass( |
| mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); |
| } |
| |
| void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, |
| mlir::OpPassManager* pass_manager) { |
| mlir::TF::StandardPipelineOptions standard_pipeline_options; |
| standard_pipeline_options.enable_inliner = false; |
| standard_pipeline_options.form_clusters = pass_config.form_clusters; |
| mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options); |
| pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass()); |
| |
| if (pass_config.shape_inference) { |
| pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); |
| } |
| // Keep this pass after the shape inference pass, which couldn't do shape |
| // inference for non-tf ops. |
| if (!pass_config.quant_specs.serialized_quant_stats.empty()) { |
| pass_manager->addPass( |
| mlir::quant::CreateImportQuantStatsPassForTFControlDialect( |
| pass_config.quant_specs.serialized_quant_stats)); |
| } |
| |
| pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); |
| |
| // The conversion pipeline has to follow the following orders: |
| // 1) Saved model related optimization like decompose resource ops |
| // 2) Convert composite functions like lstm/rnns, along with proper function |
| // inlining & dce. |
| // 3) Lower static tensor list pass. |
| |
| // This decomposes resource ops like ResourceGather into read-variable op |
| // followed by gather. This is used when the saved model import path is used |
| // during which resources dont get frozen in the python layer. |
| pass_manager->addNestedPass<mlir::FuncOp>( |
| mlir::TFDevice::CreateDecomposeResourceOpsPass()); |
| |
| // Note: |
| // We need to fuse composite ops before LowerStaticTensorList pass. |
| // The tensorflow list is not supported right now by that pass. |
| // Enable fusing composite ops that can be lowered to built-in TFLite ops. |
| if (pass_config.emit_builtin_tflite_ops) { |
| pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); |
| } |
| |
| pass_manager->addPass(mlir::createInlinerPass()); |
| pass_manager->addPass(mlir::createSymbolDCEPass()); |
| |
| if (pass_config.lower_tensor_list_ops) { |
| // TODO(haoliang): Add this pass by default. |
| pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); |
| } |
| |
| // This pass does resource analysis of saved model global tensors and marks |
| // those deemed read-only as immutable. |
| pass_manager->addPass( |
| mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); |
| |
| if (pass_config.shape_inference) { |
| // Add a shape inference pass to optimize away the unnecessary casts. |
| pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); |
| } |
| |
| pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); |
| |
| // Legalize while early to allow further constant folding. |
| // TODO(jpienaar): This may not actually matter as we do canonicalization |
| // after the legalize below, for now it needs to be below the above passes |
| // that work on TF dialect and before inliner so that the function calls in |
| // body and cond are inlined for optimization. |
| if (pass_config.legalize_tf_while) { |
| pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass()); |
| } |
| |
| // Add function inlining pass. Both TF and TFLite dialects are opted into |
| // function inliner interface. |
| pass_manager->addPass(mlir::createInlinerPass()); |
| |
| // TODO(jpienaar): Revise post dialect constants. |
| pass_manager->addPass(mlir::TF::CreateDecodeConstantPass()); |
| // Canonicalization includes const folding, which is utilized here to optimize |
| // away ops that can't get constant folded after PrepareTF pass. For example, |
| // tf.Conv2D is split into tf.Transpose and tfl.Conv2D. |
| pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); |
| pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); |
| // This pass does dead code elimination based on symbol visibility. |
| pass_manager->addPass(mlir::createSymbolDCEPass()); |
| // This pass 'freezes' immutable global tensors and inlines them as tf |
| // constant ops. |
| pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); |
| |
| // The below passes only make sense if Builtin TFLite ops are enabled |
| // for emission. |
| if (pass_config.emit_builtin_tflite_ops) { |
| // Prepare for TFLite dialect, rerun canonicalization, and then legalize to |
| // the TFLite dialect. |
| pass_manager->addPass( |
| mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul)); |
| pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); |
| if (pass_config.shape_inference) { |
| // Add a shape inference pass to optimize away the unnecessary casts. |
| pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); |
| } |
| pass_manager->addPass( |
| mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification)); |
| pass_manager->addPass(mlir::TFL::CreateOptimizePass()); |
| // This pass operates on TensorFlow ops but is triggered after legalization |
| // so that it can target constants introduced once TensorFlow Identity ops |
| // are removed during legalization. |
| pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); |
| pass_manager->addPass(mlir::createSymbolDCEPass()); |
| pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); |
| pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); |
| // This pass should be always at the end of the floating point model |
| // conversion. Some TFL ops like unidirectional |
| // sequence lstm will have stateful operands and some optimization passes |
| // will merge those operands if they have identical values & types. However, |
| // it's not desired by TFL. This pass serves as a "fix" pass to split the |
| // merged inputs until we have 1st class variable support or reuse |
| // tf.variable to model this. |
| pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass()); |
| |
| // Run quantization after all the floating point model conversion is |
| // completed. |
| if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) { |
| AddQuantizationPasses(pass_config.quant_specs, pass_manager); |
| } |
| } |
| } |
| |
| } // namespace tensorflow |
| |
| namespace mlir { |
| namespace TFL { |
| |
| struct StandardPipelineOptions |
| : public PassPipelineOptions<StandardPipelineOptions> { |
| // TODO(b/150915052): All the tf_tfl_translate_cl flags should |
| // move inside this. |
| }; |
| |
| // NOLINTNEXTLINE |
| // This creates the standard pass pipeline for TF->TFLite. This |
| // represents a std configuration for TFLite, for use with APIs like |
| // tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline |
| // This does not yet include quantization passes. |
| void CreateTFLStandardPipeline(OpPassManager& pm, |
| const StandardPipelineOptions& options) { |
| OpPassManager& func_pm = pm.nest<FuncOp>(); |
| |
| // tf_executor dialect passes - Cleaning up the IR. |
| mlir::TF::StandardPipelineOptions standard_pipeline_options; |
| mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options); |
| |
| // This is needed for control flow support with TF TensorList. |
| pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass()); |
| |
| // Saved model pass to mark global tensors immutable. |
| pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); |
| // Op fusion pass. |
| pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); |
| |
| pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass()); |
| |
| pm.addPass(mlir::createInlinerPass()); |
| |
| // Canonicalize, CSE etc. |
| pm.addPass(mlir::TF::CreateDecodeConstantPass()); |
| pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); |
| pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); |
| // DCE for private symbols. |
| pm.addPass(mlir::createSymbolDCEPass()); |
| |
| // freeze global tensors. |
| pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); |
| |
| // TFLite dialect passes. |
| pm.addPass(mlir::TFL::CreatePrepareTFPass(true)); |
| pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); |
| pm.addPass( |
| mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true)); |
| pm.addPass(mlir::TFL::CreateOptimizePass()); |
| pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); |
| pm.addPass(mlir::createSymbolDCEPass()); |
| |
| // Canonicalize, CSE etc. |
| pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); |
| pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); |
| |
| // Pass for stateful operands like LSTM. |
| pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass()); |
| |
| pm.addPass(mlir::TFL::CreateWhileOutlinePass()); |
| |
| pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); |
| } |
| |
| // Registers a pass pipeline for the standard TFL passes. |
| static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline( |
| "tfl-standard-pipeline", |
| "Run the standard passes involved in transforming/optimizing the TF " |
| "program to TFLite after " |
| "importing into MLIR.", |
| CreateTFLStandardPipeline); |
| |
| } // namespace TFL |
| } // namespace mlir |