/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OpPassBase<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
namespace tensorflow {

void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
                           mlir::OpPassManager* pass_manager) {
  pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs));
  pass_manager->addPass(mlir::TFL::CreateQuantizePass());
  bool emit_quant_adaptor_ops =
      quant_specs.inference_type != quant_specs.inference_input_type;
  pass_manager->addPass(
      mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
  if (quant_specs.default_ranges.first.hasValue() ||
      quant_specs.default_ranges.second.hasValue()) {
    pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
        quant_specs.default_ranges.first.getValueOr(0.0),
        quant_specs.default_ranges.second.getValueOr(0.0)));
    pass_manager->addPass(mlir::TFL::CreateQuantizePass());
    pass_manager->addPass(
        mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
  }
}
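
// Illustrative usage sketch for AddQuantizationPasses above (not code from
// this converter; the concrete spec settings below are assumptions for
// illustration only):
//
//   mlir::TFL::QuantizationSpecs quant_specs;
//   quant_specs.inference_type = tensorflow::DT_QINT8;        // assumed
//   quant_specs.inference_input_type = tensorflow::DT_FLOAT;  // assumed
//   mlir::PassManager pm(module.getContext());
//   AddQuantizationPasses(quant_specs, &pm);
//   if (failed(pm.run(module))) { /* handle pass failure */ }
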
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
                                mlir::OpPassManager* pass_manager) {
  pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
  if (pass_config.skip_control_dialect) {
    // Merge islands.
    pass_manager->addPass(
        mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
    // Assuming the island coarsening above results in a graph with a single
    // island, a canonicalization pass can be run to hoist the ops of that
    // single island out.
    pass_manager->addPass(mlir::createCanonicalizerPass());
    if (pass_config.form_clusters)
      pass_manager->addPass(mlir::TFDevice::CreateClusterFormationPass());
  } else {
    pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
    pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
  }

  if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
    pass_manager->addPass(
        mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
            pass_config.quant_specs.serialized_quant_stats));
  }

  if (pass_config.lower_tensor_list_ops) {
    // TODO(haoliang): Add this pass by default.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  }

  if (pass_config.saved_model_import) {
    // This pass does resource analysis of saved model global tensors and
    // marks those deemed read-only as immutable.
    pass_manager->addPass(
        mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
    // This pass marks non-exported functions with 'private' symbol visibility.
    pass_manager->addPass(
        mlir::tf_saved_model::
            CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
  }

  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  // The ophint extractions happen before lots of other passes:
  // The assumption of ophint extraction is that each ophinted region is a
  // black box, and the nodes within this black box are NOT connected to the
  // nodes OUTSIDE the black box.
  // Some passes may merge nodes together (such as const nodes); however, this
  // would break the ophint-extraction assumption, since the nodes within the
  // black box would no longer be isolated.
  // So ophint extraction and legalization need to happen before the
  // canonicalization pass.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreateExtractOphintPass());
    // The composite-op conversion pass runs after the ophint extraction pass.
    pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass());
  }

  // Legalize while early to allow further constant folding.
  // TODO(jpienaar): This may not actually matter as we do canonicalization
  // after the legalize below; for now it needs to stay below the above passes
  // that work on the TF dialect and before the inliner, so that the function
  // calls in body and cond are inlined for optimization.
  if (pass_config.legalize_tf_while) {
    pass_manager->addNestedPass<mlir::FuncOp>(
        mlir::TFL::CreateLegalizeTFWhilePass());
  }

  if (pass_config.inline_functions) {
    pass_manager->addPass(mlir::createInlinerPass());
  }

  // TODO(jpienaar): Revise post dialect constants.
  pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
  // Canonicalization includes const folding, which is utilized here to
  // optimize away ops that can't get constant folded after the PrepareTF
  // pass. For example, tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
  // This pass does dead code elimination based on symbol visibility.
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.saved_model_import) {
    // This pass 'freezes' immutable global tensors and inlines them as tf
    // constant ops.
    pass_manager->addPass(
        mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
  }

  // The below passes only make sense if Builtin TFLite ops are enabled
  // for emission.
  if (pass_config.emit_builtin_tflite_ops) {
    // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
    // the TFLite dialect.
    pass_manager->addPass(
        mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
    pass_manager->addPass(mlir::TFL::CreateOptimizePass());
    // This pass operates on TensorFlow ops but is triggered after legalization
    // so that it can target constants introduced once TensorFlow Identity ops
    // are removed during legalization.
    pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
    // This pass should always be at the end of the floating point model
    // conversion. Some TFL ops, such as unidirectional sequence lstm, have
    // stateful operands, and some optimization passes will merge those
    // operands if they have identical values & types. However, this is not
    // desired by TFL. This pass serves as a "fix" pass to split the merged
    // inputs until we have first-class variable support or reuse tf.variable
    // to model this.
    pass_manager->addPass(mlir::TFL::CreateSplitMergedOperandsPass());

    // Run quantization after all the floating point model conversion is
    // completed.
    if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
      AddQuantizationPasses(pass_config.quant_specs, pass_manager);
    }
  }
}
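
// Illustrative usage sketch for AddTFToTFLConversionPasses above (an
// assumption about how a driver might call it, not code taken from this
// file; the field settings shown are examples only):
//
//   mlir::TFL::PassConfig pass_config(quant_specs);
//   pass_config.emit_builtin_tflite_ops = true;   // example setting
//   pass_config.lower_tensor_list_ops = true;     // example setting
//   mlir::PassManager pm(module.getContext());
//   AddTFToTFLConversionPasses(pass_config, &pm);
//   if (failed(pm.run(module))) { /* handle pass failure */ }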
} // namespace tensorflow