blob: 6dd44e666fbb4b82ef836b45f5680221d3e928ba [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include <ostream>
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using stream_executor::port::StatusOr;
namespace tensorflow {
namespace internal {
namespace {
// Op def string for TFLite_Detection_PostProcess Op.
// This is parsed as a tensorflow::OpDef text proto and registered with the
// global op registry at runtime (see RegisterCustomBuiltinOps below) so that
// graphs containing this custom TFLite op can be imported even though no
// compiled-in kernel exists for it. Do not edit the string contents without
// auditing graph producers that emit this op.
const char kDetectionPostProcessOp[] =
"name: 'TFLite_Detection_PostProcess' input_arg: { name: "
"'raw_outputs/box_encodings' type: DT_FLOAT } input_arg: { name: "
"'raw_outputs/class_predictions' type: DT_FLOAT } input_arg: { name: "
"'anchors' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:1' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:2' type: DT_FLOAT } output_arg: { name: "
"'TFLite_Detection_PostProcess:3' type: DT_FLOAT } attr : { name: "
"'h_scale' type: 'float'} attr : { name: 'max_classes_per_detection' "
"type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
"name: 'nms_iou_threshold' type: 'float'} attr : { name: "
"'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
"'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
"type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
// Op def string (OpDef text proto) for the fused UnidirectionalSequenceLstm
// op, registered at runtime via RegisterCustomBuiltinOps so imported graphs
// that contain it pass op validation.
// NOTE(review): the input name 'kCellGateBias' (with a 'k' prefix) looks like
// a typo for 'CellGateBias', but the string is matched against producer
// graphs at runtime — do not change it without auditing those producers.
const char kUnidirectionalSequenceLstmOp[] =
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
"name: 'InputCellStateTensor' type: DT_FLOAT } "
"output_arg: { name: 'Concat' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
// Op def string (OpDef text proto) for the fused UnidirectionalSequenceRnn
// op, registered at runtime via RegisterCustomBuiltinOps so imported graphs
// that contain it pass op validation.
const char kUnidirectionalSequenceRnnOp[] =
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
"name: 'Bias' type: DT_FLOAT} "
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
"DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
// Maps a toco::IODataType onto the corresponding tensorflow::DataType.
// Only the constants exposed through the TFLite Python API are handled;
// note that QUANTIZED_UINT8/INT8 intentionally map to the quantized
// DT_QUINT8/DT_QINT8 types. Anything unrecognized maps to DT_INVALID.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
  switch (dtype) {
    case toco::IODataType::FLOAT:           return DT_FLOAT;
    case toco::IODataType::FLOAT16:         return DT_HALF;
    case toco::IODataType::FLOAT64:         return DT_DOUBLE;
    case toco::IODataType::QUANTIZED_UINT8: return DT_QUINT8;
    case toco::IODataType::INT8:            return DT_QINT8;
    case toco::IODataType::INT32:           return DT_INT32;
    case toco::IODataType::INT64:           return DT_INT64;
    case toco::IODataType::STRING:          return DT_STRING;
    case toco::IODataType::BOOL:            return DT_BOOL;
    default:                                return DT_INVALID;
  }
}
// Converts TOCO-style input stats (mean/std of the quantized values) into the
// real-valued (min, max) range implied by the quantized type's [qmin, qmax].
// Returns InvalidArgument for any type other than DT_QUINT8/DT_QINT8, or when
// `std` is zero (which would otherwise divide by zero and produce an
// infinite/NaN range).
StatusOr<std::pair<double, double>> InputStatsToMinMax(double mean, double std,
                                                       DataType type) {
  // Only qint8 and quint8 are considered here.
  double qmin, qmax;
  if (type == DT_QUINT8) {
    qmin = 0.0;
    qmax = 255.0;
  } else if (type == DT_QINT8) {
    qmin = -128.0;
    qmax = 127.0;
  } else {
    return errors::InvalidArgument("Only int8 and uint8 are considered.");
  }
  // Guard against a degenerate scale: (qmin - mean) / 0.0 would silently
  // yield +/-inf and poison the derived quantization range downstream.
  if (std == 0.0) {
    return errors::InvalidArgument(
        "The standard deviation of the input stats must be nonzero.");
  }
  return std::make_pair((qmin - mean) / std, (qmax - mean) / std);
}
// Parses each entry of `extra_tf_opdefs` as a tensorflow::OpDef text proto
// and registers it with the global op registry; entries whose op name is
// already registered are skipped. Returns InvalidArgument if any entry fails
// to parse.
// Note: takes the vector by const reference — the previous by-value signature
// copied every opdef string on each call for no benefit.
Status RegisterCustomBuiltinOps(const std::vector<string>& extra_tf_opdefs) {
  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                           &opdef)) {
      return errors::InvalidArgument("fail to parse extra OpDef");
    }
    // Make sure the op is not already registered. If registered continue.
    const OpRegistrationData* op_reg =
        tensorflow::OpRegistry::Global()->LookUp(opdef.name());
    if (op_reg) continue;
    // The lambda captures the parsed OpDef by value so the registration
    // callback stays valid after this function returns.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return Status::OK();
        });
  }
  return Status::OK();
}
} // namespace
// Registers all custom OpDefs needed by the converter: the user-supplied
// opdefs from `toco_flags` plus the built-in fused ops (detection
// post-process, unidirectional LSTM/RNN) defined above.
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
  std::vector<string> op_defs;
  op_defs.reserve(toco_flags.custom_opdefs().size() + 3);
  // User-provided opdefs first, then the converter's built-in fused ops.
  op_defs.insert(op_defs.end(), toco_flags.custom_opdefs().begin(),
                 toco_flags.custom_opdefs().end());
  op_defs.push_back(kDetectionPostProcessOp);
  op_defs.push_back(kUnidirectionalSequenceLstmOp);
  op_defs.push_back(kUnidirectionalSequenceRnnOp);
  return RegisterCustomBuiltinOps(op_defs);
}
// Derives the MLIR quantization specs and per-input-node metadata
// (`node_names`, `node_dtypes`, `node_shapes`, `node_mins`, `node_maxs`)
// from the TOCO model/converter flags. The four vectors are appended to in
// lockstep, one entry per input array (mins/maxs only for quantized types).
// Returns InvalidArgument if the input quant spec cannot be constructed.
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs) {
quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type =
ConvertIODataTypeToDataType(toco_flags.inference_type());
// Use non-float flag `inference_input_type` to override the `inference_type`
// because we have to apply quantization to satisfy that.
if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
inference_type = quant_specs->inference_input_type;
}
for (auto& flag : model_flags.input_arrays()) {
node_names->push_back(flag.name());
// TOCO doesn't required `data_type` to be filled for every input.
// If it's not filled, make it an empty string so the importer will use
// the data type in the NodeDef.
auto toco_data_type = flag.data_type();
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
node_dtypes->push_back("");
} else {
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
// Currently, only UINT8 and INT8 require inputs stats
// (mean/std are converted to a real-valued min/max range here, so
// node_mins/node_maxs may end up shorter than node_names for float models).
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
}
}
// GetInputNodeQuantSpecs returns true on failure.
if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
inference_type, quant_specs)) {
return errors::InvalidArgument("Failed to get input quant spec.");
}
// Some extra flag related to post training quantization. If post-training
// quantization is enabled, `inference_type` and `inference_input_type` are
// not used by MLIR passes.
if (toco_flags.post_training_quantize()) {
quant_specs->weight_quantization = true;
if (toco_flags.quantize_to_float16()) {
quant_specs->inference_type = tensorflow::DT_HALF;
quant_specs->inference_input_type = tensorflow::DT_HALF;
} else {
quant_specs->inference_type = tensorflow::DT_QINT8;
quant_specs->inference_input_type = tensorflow::DT_QINT8;
}
}
// Other flags.
// Default ranges are applied to activations that have no quantization stats,
// via the (min, max) pair stored in `default_ranges`.
if (toco_flags.has_default_ranges_min()) {
quant_specs->default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs->default_ranges.second = toco_flags.default_ranges_max();
}
return ::tensorflow::Status::OK();
}
// Dumps the op graph of the `module` to `filename` in DOT format.
// Returns InvalidArgument if the file cannot be opened, or Unknown if the
// print-op-graph pass fails.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
  std::string error_message;
  auto output = mlir::openOutputFile(filename, &error_message);
  if (!error_message.empty()) {
    // errors::InvalidArgument concatenates its arguments (StrCat semantics);
    // it does NOT do printf-style substitution, so the previous "%s" format
    // string was emitted literally and the filename was appended after the
    // period. Build the message by concatenation and include the underlying
    // error for easier debugging.
    return errors::InvalidArgument("Failed to open file in ", filename, ": ",
                                   error_message);
  }
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createPrintOpGraphPass(output->os()));
  if (failed(pm.run(module))) {
    return errors::Unknown("Failed to dump Op Graph from MLIR module.");
  }
  // ToolOutputFile deletes the file on destruction unless keep() is called.
  output->keep();
  return Status::OK();
}
// Lowers the imported TF `module` to a TFLite flatbuffer written into
// `result`, running the TF->TFL conversion pass pipeline configured by
// `pass_config` and honoring the op-emission switches in `toco_flags`.
// Optionally dumps the op graph (DOT) before and after the transformations
// when `dump_graphviz_dir` is set.
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config,
string* result) {
// force_select_tf_ops suppresses builtin TFLite op emission entirely.
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
// rename once we enable the new converter feature flag.
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
}
mlir::PassManager pm(module->getContext());
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
// Runtime verification must run last, after all transformations.
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
&pm);
// The post-transformation dump runs even if conversion failed above; the
// conversion status is still returned to the caller afterwards.
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
"/toco_AFTER_TRANSFORMATIONS.dot")));
}
return status;
}
// Logs a warning for each legacy TOCO flag that is set but has no effect in
// the MLIR-based converter. Purely informational; never fails.
void WarningUnusedFlags(const toco::ModelFlags& model_flags,
                        const toco::TocoFlags& toco_flags) {
  // Small helper so each ignored flag is a single table-like line below.
  const auto warn_if = [](bool flag_is_set, const char* message) {
    if (flag_is_set) {
      LOG(WARNING) << message;
    }
  };
  warn_if(toco_flags.output_format(), "Ignored output_format.");
  warn_if(toco_flags.drop_control_dependency(),
          "Ignored drop_control_dependency.");
  warn_if(toco_flags.reorder_across_fake_quant(),
          "Ignored reorder_across_fake_quant.");
  warn_if(model_flags.change_concat_input_ranges(),
          "Ignored change_concat_input_ranges.");
  warn_if(toco_flags.dump_graphviz_include_video(),
          "Ignored dump_graphviz_video.");
  warn_if(model_flags.allow_nonexistent_arrays(),
          "Allow allow_nonexistent_arrays.");
}
} // namespace internal
} // namespace tensorflow