/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"

#include <ostream>
#include <string>
#include <utility>

#include "llvm/ADT/None.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
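// Converts the given GraphDef to a TFLite flatbuffer according to the TOCO
// model and conversion flags, writing the serialized flatbuffer to `result`.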
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                         const toco::TocoFlags& toco_flags,
                                         const GraphDebugInfo& debug_info,
                                         const GraphDef& input,
                                         string* result) {
  using ::tflite::optimize::ReducedPrecisionSupport;
  mlir::MLIRContext context;
  GraphImportConfig specs;
  mlir::quant::QuantizationSpecs quant_specs;

  // Parse input arrays.
  std::vector<string> node_names;
  std::vector<string> node_dtypes;
  std::vector<llvm::Optional<std::vector<int>>> node_shapes;
  std::vector<llvm::Optional<double>> node_mins;
  std::vector<llvm::Optional<double>> node_maxs;

  // Populate quantization specs.
  TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(
      model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
      &node_shapes, &node_mins, &node_maxs));

  TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo(
      node_names, node_dtypes, node_shapes, &specs.inputs));

  // Parse output arrays.
  std::vector<string> output_arrays(model_flags.output_arrays().begin(),
                                    model_flags.output_arrays().end());
  TF_RETURN_IF_ERROR(
      tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));

  // Parse control output arrays.
  std::vector<string> control_output_arrays(
      model_flags.control_output_arrays().begin(),
      model_flags.control_output_arrays().end());
  TF_RETURN_IF_ERROR(tensorflow::ParseOutputArrayInfo(control_output_arrays,
                                                      &specs.control_outputs));

  specs.prune_unused_nodes = true;
  specs.convert_legacy_fed_inputs = true;
  specs.graph_as_function = false;
  specs.upgrade_legacy = true;
  specs.unconditionally_use_set_output_shapes = true;
  internal::WarningUnusedFlags(model_flags, toco_flags);

  // Register all custom ops, including user-specified custom ops.
  TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));

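  // Import the GraphDef, together with its debug info and the import specs
  // assembled above, into an MLIR module in the TF dialect.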
  TF_ASSIGN_OR_RETURN(
      auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));

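  // Configure the TF-to-TFLite lowering pass pipeline from the conversion
  // flags.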
  mlir::TFL::PassConfig pass_config(quant_specs);
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
  pass_config.unfold_batch_matmul = toco_flags.unfold_batchmatmul();
  pass_config.lower_tensor_list_ops = toco_flags.lower_tensor_list_ops();
  // Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
  // conversion to an unsupported 16x16 TFL::FullyConnectedOp.
  if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
    pass_config.unfold_batch_matmul = false;
  }
  pass_config.unfold_large_splat_constant =
      toco_flags.unfold_large_splat_constant();
  pass_config.enable_dynamic_update_slice =
      toco_flags.enable_dynamic_update_slice();
  pass_config.preserve_assert_op = toco_flags.preserve_assert_op();
  pass_config.guarantee_all_funcs_one_use =
      toco_flags.guarantee_all_funcs_one_use();

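  // Run the pass pipeline and serialize the lowered module into `result` as
  // a TFLite flatbuffer. A plain GraphDef conversion carries no saved-model
  // tags and no session.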
  return internal::ConvertMLIRToTFLiteFlatBuffer(
      model_flags, toco_flags, std::move(module), pass_config,
      /*saved_model_tags=*/{}, result,
      /*session=*/llvm::None);
}
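
// Example usage (a minimal sketch, not compiled into this file): the flag
// setup below is hypothetical, and `graph_def` / `debug_info` are assumed to
// have been deserialized by the caller, e.g. from serialized protos.
//
//   toco::ModelFlags model_flags;
//   toco::InputArray* in = model_flags.add_input_arrays();
//   in->set_name("input");
//   model_flags.add_output_arrays("output");
//   toco::TocoFlags toco_flags;
//   GraphDebugInfo debug_info;
//   GraphDef graph_def;  // Assumed populated, e.g. via ParseFromString().
//   std::string flatbuffer;
//   Status status = ConvertGraphDefToTFLiteFlatBuffer(
//       model_flags, toco_flags, debug_info, graph_def, &flatbuffer);
//   if (!status.ok()) {
//     LOG(ERROR) << "Conversion failed: " << status;
//   }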

}  // namespace tensorflow