// blob: 790b93b42299bdf2f21b4e5ccd3b9fba76f35b29 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
#include <cstdio>
#include <cstring>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/ethosu.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
template <unsigned int tOpCount>
class MicroMutableOpResolver : public MicroOpResolver {
public:
TF_LITE_REMOVE_VIRTUAL_DELETE
explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
: error_reporter_(error_reporter) {}
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
if (op == BuiltinOperator_CUSTOM) return nullptr;
for (unsigned int i = 0; i < registrations_len_; ++i) {
const TfLiteRegistration& registration = registrations_[i];
if (registration.builtin_code == op) {
return &registration;
}
}
return nullptr;
}
const TfLiteRegistration* FindOp(const char* op) const override {
for (unsigned int i = 0; i < registrations_len_; ++i) {
const TfLiteRegistration& registration = registrations_[i];
if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
(strcmp(registration.custom_name, op) == 0)) {
return &registration;
}
}
return nullptr;
}
MicroOpResolver::BuiltinParseFunction GetOpDataParser(
BuiltinOperator op) const override {
TFLITE_DCHECK(num_buitin_ops_ <= tOpCount);
for (unsigned int i = 0; i < num_buitin_ops_; ++i) {
if (builtin_codes_[i] == op) return builtin_parsers_[i];
}
return nullptr;
}
// Registers a Custom Operator with the MicroOpResolver.
//
// Only the first call for a given name will be successful. i.e. if this
// function is called again for a previously added Custom Operator, the
// MicroOpResolver will be unchanged and this function will return
// kTfLiteError.
TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration) {
if (registrations_len_ >= tOpCount) {
if (error_reporter_) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"Couldn't register custom op '%s', resolver size is too small (%d)",
name, tOpCount);
}
return kTfLiteError;
}
if (FindOp(name) != nullptr) {
if (error_reporter_ != nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Calling AddCustom for the same op more than once "
"is not supported (Op: %s).",
name);
}
return kTfLiteError;
}
TfLiteRegistration* new_registration = &registrations_[registrations_len_];
registrations_len_ += 1;
*new_registration = *registration;
new_registration->builtin_code = BuiltinOperator_CUSTOM;
new_registration->custom_name = name;
return kTfLiteOk;
}
// The Add* functions below add the various Builtin operators to the
// MicroMutableOpResolver object.
TfLiteStatus AddAbs() {
return AddBuiltin(BuiltinOperator_ABS, tflite::ops::micro::Register_ABS(),
ParseAbs);
}
TfLiteStatus AddAdd() {
return AddBuiltin(BuiltinOperator_ADD, tflite::ops::micro::Register_ADD(),
ParseAdd);
}
TfLiteStatus AddArgMax() {
return AddBuiltin(BuiltinOperator_ARG_MAX,
tflite::ops::micro::Register_ARG_MAX(), ParseArgMax);
}
TfLiteStatus AddArgMin() {
return AddBuiltin(BuiltinOperator_ARG_MIN,
tflite::ops::micro::Register_ARG_MIN(), ParseArgMin);
}
TfLiteStatus AddAveragePool2D() {
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
tflite::ops::micro::Register_AVERAGE_POOL_2D(),
ParsePool);
}
TfLiteStatus AddCeil() {
return AddBuiltin(BuiltinOperator_CEIL, tflite::ops::micro::Register_CEIL(),
ParseCeil);
}
TfLiteStatus AddCircularBuffer() {
return AddCustom("CIRCULAR_BUFFER",
tflite::ops::micro::Register_CIRCULAR_BUFFER());
}
TfLiteStatus AddConcatenation() {
return AddBuiltin(BuiltinOperator_CONCATENATION,
tflite::ops::micro::Register_CONCATENATION(),
ParseConcatenation);
}
TfLiteStatus AddConv2D() {
return AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), ParseConv2D);
}
TfLiteStatus AddCos() {
return AddBuiltin(BuiltinOperator_COS, tflite::ops::micro::Register_COS(),
ParseCos);
}
TfLiteStatus AddDepthwiseConv2D() {
return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
Register_DEPTHWISE_CONV_2D(), ParseDepthwiseConv2D);
}
TfLiteStatus AddDequantize() {
return AddBuiltin(BuiltinOperator_DEQUANTIZE,
tflite::ops::micro::Register_DEQUANTIZE(),
ParseDequantize);
}
TfLiteStatus AddDetectionPostprocess() {
return AddCustom("TFLite_Detection_PostProcess",
tflite::Register_DETECTION_POSTPROCESS());
}
TfLiteStatus AddEqual() {
return AddBuiltin(BuiltinOperator_EQUAL,
tflite::ops::micro::Register_EQUAL(), ParseEqual);
}
TfLiteStatus AddEthosU() {
TfLiteRegistration* registration = tflite::Register_ETHOSU();
if (registration) {
return AddCustom(tflite::GetString_ETHOSU(), registration);
}
return kTfLiteOk;
}
TfLiteStatus AddFloor() {
return AddBuiltin(BuiltinOperator_FLOOR,
tflite::ops::micro::Register_FLOOR(), ParseFloor);
}
TfLiteStatus AddFullyConnected(
const TfLiteRegistration& registration = Register_FULLY_CONNECTED()) {
return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, registration,
ParseFullyConnected);
}
TfLiteStatus AddGreater() {
return AddBuiltin(BuiltinOperator_GREATER,
tflite::ops::micro::Register_GREATER(), ParseGreater);
}
TfLiteStatus AddGreaterEqual() {
return AddBuiltin(BuiltinOperator_GREATER_EQUAL,
tflite::ops::micro::Register_GREATER_EQUAL(),
ParseGreaterEqual);
}
TfLiteStatus AddHardSwish() {
return AddBuiltin(BuiltinOperator_HARD_SWISH,
tflite::ops::micro::Register_HARD_SWISH(),
ParseHardSwish);
}
TfLiteStatus AddL2Normalization() {
return AddBuiltin(BuiltinOperator_L2_NORMALIZATION,
tflite::ops::micro::Register_L2_NORMALIZATION(),
ParseL2Normalization);
}
TfLiteStatus AddLess() {
return AddBuiltin(BuiltinOperator_LESS, tflite::ops::micro::Register_LESS(),
ParseLess);
}
TfLiteStatus AddLessEqual() {
return AddBuiltin(BuiltinOperator_LESS_EQUAL,
tflite::ops::micro::Register_LESS_EQUAL(),
ParseLessEqual);
}
TfLiteStatus AddLog() {
return AddBuiltin(BuiltinOperator_LOG, tflite::ops::micro::Register_LOG(),
ParseLog);
}
TfLiteStatus AddLogicalAnd() {
return AddBuiltin(BuiltinOperator_LOGICAL_AND,
tflite::ops::micro::Register_LOGICAL_AND(),
ParseLogicalAnd);
}
TfLiteStatus AddLogicalNot() {
return AddBuiltin(BuiltinOperator_LOGICAL_NOT,
tflite::ops::micro::Register_LOGICAL_NOT(),
ParseLogicalNot);
}
TfLiteStatus AddLogicalOr() {
return AddBuiltin(BuiltinOperator_LOGICAL_OR,
tflite::ops::micro::Register_LOGICAL_OR(),
ParseLogicalOr);
}
TfLiteStatus AddLogistic() {
return AddBuiltin(BuiltinOperator_LOGISTIC,
tflite::ops::micro::Register_LOGISTIC(), ParseLogistic);
}
TfLiteStatus AddMaximum() {
return AddBuiltin(BuiltinOperator_MAXIMUM,
tflite::ops::micro::Register_MAXIMUM(), ParseMaximum);
}
TfLiteStatus AddMaxPool2D() {
return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
tflite::ops::micro::Register_MAX_POOL_2D(), ParsePool);
}
TfLiteStatus AddMean() {
return AddBuiltin(BuiltinOperator_MEAN, tflite::ops::micro::Register_MEAN(),
ParseReducer);
}
TfLiteStatus AddMinimum() {
return AddBuiltin(BuiltinOperator_MINIMUM,
tflite::ops::micro::Register_MINIMUM(), ParseMinimum);
}
TfLiteStatus AddMul() {
return AddBuiltin(BuiltinOperator_MUL, tflite::ops::micro::Register_MUL(),
ParseMul);
}
TfLiteStatus AddNeg() {
return AddBuiltin(BuiltinOperator_NEG, tflite::ops::micro::Register_NEG(),
ParseNeg);
}
TfLiteStatus AddNotEqual() {
return AddBuiltin(BuiltinOperator_NOT_EQUAL,
tflite::ops::micro::Register_NOT_EQUAL(), ParseNotEqual);
}
TfLiteStatus AddPack() {
return AddBuiltin(BuiltinOperator_PACK, tflite::ops::micro::Register_PACK(),
ParsePack);
}
TfLiteStatus AddPad() {
return AddBuiltin(BuiltinOperator_PAD, tflite::ops::micro::Register_PAD(),
ParsePad);
}
TfLiteStatus AddPadV2() {
return AddBuiltin(BuiltinOperator_PADV2,
tflite::ops::micro::Register_PADV2(), ParsePadV2);
}
TfLiteStatus AddPrelu() {
return AddBuiltin(BuiltinOperator_PRELU,
tflite::ops::micro::Register_PRELU(), ParsePrelu);
}
TfLiteStatus AddQuantize() {
return AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
ParseQuantize);
}
TfLiteStatus AddReduceMax() {
return AddBuiltin(BuiltinOperator_REDUCE_MAX,
tflite::ops::micro::Register_REDUCE_MAX(), ParseReducer);
}
TfLiteStatus AddRelu() {
return AddBuiltin(BuiltinOperator_RELU, tflite::ops::micro::Register_RELU(),
ParseRelu);
}
TfLiteStatus AddRelu6() {
return AddBuiltin(BuiltinOperator_RELU6,
tflite::ops::micro::Register_RELU6(), ParseRelu6);
}
TfLiteStatus AddReshape() {
return AddBuiltin(BuiltinOperator_RESHAPE,
tflite::ops::micro::Register_RESHAPE(), ParseReshape);
}
TfLiteStatus AddResizeNearestNeighbor() {
return AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR(),
ParseResizeNearestNeighbor);
}
TfLiteStatus AddRound() {
return AddBuiltin(BuiltinOperator_ROUND,
tflite::ops::micro::Register_ROUND(), ParseRound);
}
TfLiteStatus AddRsqrt() {
return AddBuiltin(BuiltinOperator_RSQRT,
tflite::ops::micro::Register_RSQRT(), ParseRsqrt);
}
TfLiteStatus AddShape() {
return AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE(), ParseShape);
}
TfLiteStatus AddSin() {
return AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN(),
ParseSin);
}
TfLiteStatus AddSoftmax() {
return AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),
ParseSoftmax);
}
TfLiteStatus AddSplit() {
return AddBuiltin(BuiltinOperator_SPLIT,
tflite::ops::micro::Register_SPLIT(), ParseSplit);
}
TfLiteStatus AddSplitV() {
return AddBuiltin(BuiltinOperator_SPLIT_V,
tflite::ops::micro::Register_SPLIT_V(), ParseSplitV);
}
TfLiteStatus AddSqrt() {
return AddBuiltin(BuiltinOperator_SQRT, tflite::ops::micro::Register_SQRT(),
ParseSqrt);
}
TfLiteStatus AddSquare() {
return AddBuiltin(BuiltinOperator_SQUARE,
tflite::ops::micro::Register_SQUARE(), ParseSquare);
}
TfLiteStatus AddStridedSlice() {
return AddBuiltin(BuiltinOperator_STRIDED_SLICE,
tflite::ops::micro::Register_STRIDED_SLICE(),
ParseStridedSlice);
}
TfLiteStatus AddSub() {
return AddBuiltin(BuiltinOperator_SUB, tflite::ops::micro::Register_SUB(),
ParseSub);
}
TfLiteStatus AddSvdf() {
return AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(), ParseSvdf);
}
TfLiteStatus AddTanh() {
return AddBuiltin(BuiltinOperator_TANH, tflite::ops::micro::Register_TANH(),
ParseTanh);
}
TfLiteStatus AddUnpack() {
return AddBuiltin(BuiltinOperator_UNPACK,
tflite::ops::micro::Register_UNPACK(), ParseUnpack);
}
unsigned int GetRegistrationLength() { return registrations_len_; }
private:
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
const TfLiteRegistration& registration,
MicroOpResolver::BuiltinParseFunction parser) {
if (op == BuiltinOperator_CUSTOM) {
if (error_reporter_ != nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Invalid parameter BuiltinOperator_CUSTOM to the "
"AddBuiltin function.");
}
return kTfLiteError;
}
if (FindOp(op) != nullptr) {
if (error_reporter_ != nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Calling AddBuiltin with the same op more than "
"once is not supported (Op: #%d).",
op);
}
return kTfLiteError;
}
if (registrations_len_ >= tOpCount) {
if (error_reporter_) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Couldn't register builtin op #%d, resolver size "
"is too small (%d).",
op, tOpCount);
}
return kTfLiteError;
}
registrations_[registrations_len_] = registration;
// Strictly speaking, the builtin_code is not necessary for TFLM but filling
// it in regardless.
registrations_[registrations_len_].builtin_code = op;
registrations_len_++;
builtin_codes_[num_buitin_ops_] = op;
builtin_parsers_[num_buitin_ops_] = parser;
num_buitin_ops_++;
return kTfLiteOk;
}
TfLiteRegistration registrations_[tOpCount];
unsigned int registrations_len_ = 0;
// Arrays (and counter) to store the builtin codes and their corresponding
// parse functions as these are registered with the Op Resolver.
BuiltinOperator builtin_codes_[tOpCount];
MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
unsigned int num_buitin_ops_ = 0;
ErrorReporter* error_reporter_;
};
}; // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_