blob: 020792392a2a37fe275458ec4b2bb330654959c8 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
namespace {

// TODO(b/162842801): Consolidate all util definitions of kTFImplements.
// Function attribute that records which composite op a function implements.
constexpr char kTFImplements[] = "tf._implements";
// Used both as the `tf._implements` value and as the TFLite custom op code for
// the SSD post-processing rewrite.
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
// `tf._implements` value stamped on functions rewritten to TFL NMS V4.
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";

// Packs the raw bytes of `content` into an OpaqueElementsAttr of type
// tensor<N x i8> (N = content.size()) tagged with the "tfl" dialect. The result
// is what gets attached as the custom-options payload of a TFL CustomOp.
// NOTE(review): getLoadedDialect("tfl") returns null if the TFL dialect has not
// been loaded into the context; callers presumably guarantee it is.
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
                                       const std::string& content) {
  const int64_t num_bytes = static_cast<int64_t>(content.size());
  ShapedType byte_tensor_type =
      RankedTensorType::get({num_bytes}, builder->getIntegerType(8));
  auto* tfl_dialect = builder->getContext()->getLoadedDialect("tfl");
  const StringRef payload(content.data(), content.size());
  return OpaqueElementsAttr::get(tfl_dialect, byte_tensor_type, payload);
}

}  // namespace
// Replaces the body of the composite NMS function with a single
// TFL NonMaxSuppressionV4 op fed by the first five function arguments
// (boxes, scores, max_output_size, iou_threshold, score_threshold), and
// returns the op's two results. Also tags the function with
// `tf._implements = "non_max_suppression_padded_v2"`.
void ConvertNMSPaddedFunc::RewriteFunc() {
  // Drop the existing implementation and recreate an empty entry block.
  // Without this, the builder below would insert the new op and a `return` at
  // the start of the old body, leaving the original ops after a terminator.
  // This mirrors ConvertSSDPostProcessFunc::RewriteFunc.
  func_.eraseBody();
  func_.addEntryBlock();
  func_->setAttr(kTFImplements,
                 StringAttr::get(func_.getContext(), kTfNMSPadded));

  // Arguments must be fetched after addEntryBlock(): the previous block
  // arguments were destroyed together with the old body.
  Value boxes = func_.getArgument(0);
  Value scores = func_.getArgument(1);
  Value max_output_size = func_.getArgument(2);
  Value iou_threshold = func_.getArgument(3);
  Value score_threshold = func_.getArgument(4);
  auto output_type0 = func_.getFunctionType().getResult(0);
  auto output_type1 = func_.getFunctionType().getResult(1);

  OpBuilder builder(func_.getBody());
  auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
      func_.getLoc(), output_type0, output_type1, boxes, scores,
      max_output_size, iou_threshold, score_threshold);
  builder.create<mlir::func::ReturnOp>(func_.getLoc(), op.getResults());
}
// Checks that the function's signature can be lowered to TFL NMS V4:
// at least 5 arguments, exactly 2 results, and a rank-2 ranked `boxes`
// argument. Emits a warning on the function and returns failure otherwise.
LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
  // Only the coarse signature is checked here; per-argument characteristics
  // are enforced by the TFL op definition.
  FunctionType fn_type = func_.getFunctionType();

  const unsigned num_args = func_.getNumArguments();
  if (num_args < 5) {
    return func_.emitWarning()
           << "Invalid number of arguments to "
              "non_max_suppression_padded_v2 (need at least 5): "
           << num_args;
  }

  const unsigned num_results = fn_type.getNumResults();
  if (num_results != 2) {
    return func_.emitWarning() << "Invalid number of results from "
                                  "non_max_suppression_padded_v2 (need 2): "
                               << num_results;
  }

  // The TFLite fused op does not support batching yet.
  // TODO(b/158709815): Add support for batches with padded NMS.
  // An unranked `boxes` type yields a null RankedTensorType and is rejected
  // by the same check.
  auto boxes_type = fn_type.getInput(0).dyn_cast<RankedTensorType>();
  const bool boxes_is_rank2 = boxes_type && boxes_type.getRank() == 2;
  if (!boxes_is_rank2) {
    return func_.emitWarning() << "TFLite does not support batched input for "
                                  "non_max_suppression_padded";
  }
  return success();
}
// Replaces the function body with a single TFL CustomOp whose custom code is
// "TFLite_Detection_PostProcess". The op consumes all function arguments and
// produces all function results. Its options are serialized from the
// `tf._implements` attribute dictionary.
// Returns failure (with an error already emitted on the function) if the
// options cannot be built; in that case the function is left untouched.
LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
  // Serialize the options first. CreateNMSCustomOptions only reads attributes,
  // so a missing or mistyped attribute can be rejected before the body is
  // erased. Otherwise the function would be left with an empty,
  // terminator-less entry block.
  std::string custom_option_buffer;
  if (failed(CreateNMSCustomOptions(func_, attr_.getAttrs(),
                                    custom_option_buffer))) {
    return failure();
  }

  func_.eraseBody();
  func_.addEntryBlock();
  func_->setAttr(kTFImplements,
                 StringAttr::get(func_.getContext(), kCustomSSDPostprocessing));

  OpBuilder builder(func_.getBody());
  auto op = builder.create<CustomOp>(
      func_.getLoc(), func_.getFunctionType().getResults(),
      func_.getArguments(), kCustomSSDPostprocessing,
      CustomOption(&builder, custom_option_buffer));
  builder.create<func::ReturnOp>(func_.getLoc(), op.getResults());
  return success();
}
// Serializes the SSD post-processing options held in `attrs` into a
// FlexBuffers map and copies the finished bytes into `custom_option_buffer`.
// Required attributes:
//   integers: max_detections, max_classes_per_detection, num_classes
//   floats:   nms_score_threshold, nms_iou_threshold,
//             y_scale, x_scale, h_scale, w_scale
//   bool:     use_regular_nms (written as an integer 0/1)
// Returns failure after emitting an error on `func` at the first attribute
// that is missing or of the wrong kind.
LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
    func::FuncOp func, DictionaryAttr attrs,
    std::string& custom_option_buffer) {
  // Option keys, in the order they are added to the map.
  static constexpr const char* kIntOptions[] = {
      "max_detections", "max_classes_per_detection", "num_classes"};
  static constexpr const char* kFloatOptions[] = {
      "nms_score_threshold", "nms_iou_threshold", "y_scale",
      "x_scale",             "h_scale",           "w_scale"};

  flexbuffers::Builder fbb;
  const size_t map_start = fbb.StartMap();

  // The Add*Attr helpers emit their own diagnostic, so we only propagate.
  for (const char* key : kIntOptions) {
    if (failed(AddIntAttr(func, attrs, key, &fbb))) return failure();
  }
  for (const char* key : kFloatOptions) {
    if (failed(AddFloatAttr(func, attrs, key, &fbb))) return failure();
  }

  auto regular_nms_attr =
      attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
  if (!regular_nms_attr) {
    return func.emitError()
           << "use_regular_nms attribute is not set or not a bool";
  }
  fbb.Int("use_regular_nms", regular_nms_attr.getValue());

  fbb.EndMap(map_start);
  fbb.Finish();

  const auto& serialized = fbb.GetBuffer();
  custom_option_buffer.assign(serialized.begin(), serialized.end());
  return success();
}
// Looks up `attribute` in `attrs` and, if it is an IntegerAttr, writes it into
// the open FlexBuffers map under the same key. Otherwise emits an error on
// `func` and returns failure without touching `builder`.
LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
    func::FuncOp func, DictionaryAttr attrs, const std::string& attribute,
    flexbuffers::Builder* builder) {
  const char* key = attribute.c_str();
  if (auto value = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>()) {
    builder->Int(key, value.getInt());
    return success();
  }
  return func.emitError() << key << " attribute is not set or not an integer";
}
// Looks up `attribute` in `attrs` and, if it is a FloatAttr, writes it into the
// open FlexBuffers map under the same key as a 32-bit float. Otherwise emits an
// error on `func` and returns failure without touching `builder`.
LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
    func::FuncOp func, DictionaryAttr attrs, const std::string& attribute,
    flexbuffers::Builder* builder) {
  auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
  if (!float_attr) {
    return func.emitError()
           << attribute.c_str() << " attribute is not set or not a float";
  }
  // APFloat::convertToFloat() asserts unless the value is exactly
  // representable in IEEE single precision, so an f64-typed attribute (e.g. an
  // untyped float literal in MLIR text) would abort. Going through double
  // accepts any float semantics. For f32 attributes the float -> double ->
  // float round trip is exact, so the emitted value is unchanged.
  builder->Float(attribute.c_str(),
                 static_cast<float>(float_attr.getValueAsDouble()));
  return success();
}
// Succeeds iff `attrs` holds an IntegerAttr named `attribute`; the value itself
// is not inspected. Otherwise emits a warning on `func` and returns failure.
LogicalResult ConvertSSDPostProcessFunc::HasIntAttr(
    func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
  const bool present =
      static_cast<bool>(attrs.get(attribute).dyn_cast_or_null<IntegerAttr>());
  if (present) return success();
  return func.emitWarning()
         << attribute.c_str() << " attribute is not set or not an integer";
}
// Succeeds iff `attrs` holds a FloatAttr named `attribute`; the value itself is
// not inspected. Otherwise emits a warning on `func` and returns failure.
LogicalResult ConvertSSDPostProcessFunc::HasFloatAttr(
    func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
  const bool present =
      static_cast<bool>(attrs.get(attribute).dyn_cast_or_null<FloatAttr>());
  if (present) return success();
  return func.emitWarning()
         << attribute.c_str() << " attribute is not set or not a float";
}
// Checks that the function can be rewritten into the
// TFLite_Detection_PostProcess custom op: exactly 3 arguments, exactly 4
// results, and every attribute that CreateNMSCustomOptions requires is present
// with the right kind. Emits a warning on the function and returns failure
// otherwise.
LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
  // Verify high-level function signature.
  if (func_.getNumArguments() != 3) {
    return func_.emitWarning()
           << "Invalid number of arguments to " << kCustomSSDPostprocessing
           << ": " << func_.getNumArguments();
  }
  if (func_.getFunctionType().getNumResults() != 4) {
    return func_.emitWarning()
           << "Invalid number of results from " << kCustomSSDPostprocessing
           << ": " << func_.getFunctionType().getNumResults();
  }
  auto attrs = attr_.getAttrs();
  if (failed(HasIntAttr(func_, attrs, "max_detections")) ||
      failed(HasIntAttr(func_, attrs, "max_classes_per_detection")) ||
      failed(HasIntAttr(func_, attrs, "num_classes")) ||
      failed(HasFloatAttr(func_, attrs, "nms_score_threshold")) ||
      failed(HasFloatAttr(func_, attrs, "nms_iou_threshold")) ||
      failed(HasFloatAttr(func_, attrs, "y_scale")) ||
      failed(HasFloatAttr(func_, attrs, "x_scale")) ||
      failed(HasFloatAttr(func_, attrs, "h_scale")) ||
      failed(HasFloatAttr(func_, attrs, "w_scale"))) {
    return failure();
  }
  // CreateNMSCustomOptions also requires a BoolAttr `use_regular_nms`. Check it
  // here so a missing or mistyped value is reported during verification
  // instead of surfacing only partway through RewriteFunc.
  if (!attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>()) {
    return func_.emitWarning()
           << "use_regular_nms attribute is not set or not a bool";
  }
  return success();
}
} // namespace TFL
} // namespace mlir