Separate out parse functionality into per-operator helper functions.
Ops in this change:
* Abs
* Add
* ArgMax
* ArgMin
PiperOrigin-RevId: 317908035
Change-Id: I6c33bd83c987c92b71992c6c113d8678bc9d35d8
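
For reference, a minimal sketch of calling one of the new helpers directly.
This is not part of the change: ParseAddParams and out_params are
illustrative names, and the caller is assumed to supply the flatbuffer
Operator, an ErrorReporter, and a BuiltinDataAllocator implementation
(e.g. the one backing the micro interpreter).

    #include "tensorflow/lite/c/builtin_op_data.h"
    #include "tensorflow/lite/core/api/flatbuffer_conversions.h"

    // Sketch only: op, error_reporter, and allocator are provided by the
    // caller (for example, by the interpreter or a test harness).
    TfLiteStatus ParseAddParams(const tflite::Operator* op,
                                tflite::ErrorReporter* error_reporter,
                                tflite::BuiltinDataAllocator* allocator,
                                TfLiteAddParams** out_params) {
      void* builtin_data = nullptr;
      TfLiteStatus status = tflite::ParseAdd(
          op, tflite::BuiltinOperator_ADD, error_reporter, allocator,
          &builtin_data);
      if (status != kTfLiteOk) return status;
      // On success, ParseAdd has allocated and populated a TfLiteAddParams.
      *out_params = reinterpret_cast<TfLiteAddParams*>(builtin_data);
      return kTfLiteOk;
    }
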
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 73d785b..c496c45 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -177,6 +177,91 @@
}
}
+// We have this parse function instead of directly returning kTfLiteOk from the
+// switch-case in ParseOpData because this function is used as part of the
+// selective registration for the OpResolver implementation in micro.
+TfLiteStatus ParseAbs(const Operator*, BuiltinOperator, ErrorReporter*,
+ BuiltinDataAllocator*, void**) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus ParseAdd(const Operator* op, BuiltinOperator,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
+ CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+ SafeBuiltinDataAllocator safe_allocator(allocator);
+ std::unique_ptr<TfLiteAddParams, SafeBuiltinDataAllocator::BuiltinDataDeleter>
+ params = safe_allocator.Allocate<TfLiteAddParams>();
+ TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+ const AddOptions* schema_params = op->builtin_options_as_AddOptions();
+
+ if (schema_params != nullptr) {
+ params->activation =
+ ConvertActivation(schema_params->fused_activation_function());
+ } else {
+ // TODO(b/157480169): We should either return kTfLiteError or fill in some
+ // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+ }
+
+ *builtin_data = params.release();
+ return kTfLiteOk;
+}
+
+TfLiteStatus ParseArgMax(const Operator* op, BuiltinOperator,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
+ CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+ SafeBuiltinDataAllocator safe_allocator(allocator);
+ std::unique_ptr<TfLiteArgMaxParams,
+ SafeBuiltinDataAllocator::BuiltinDataDeleter>
+ params = safe_allocator.Allocate<TfLiteArgMaxParams>();
+ TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+ const ArgMaxOptions* schema_params = op->builtin_options_as_ArgMaxOptions();
+
+ if (schema_params != nullptr) {
+ TF_LITE_ENSURE_STATUS(ConvertTensorType(
+        schema_params->output_type(), &params->output_type, error_reporter));
+ } else {
+ // TODO(b/157480169): We should either return kTfLiteError or fill in some
+ // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+ }
+
+ *builtin_data = params.release();
+ return kTfLiteOk;
+}
+
+TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
+ CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
+
+ SafeBuiltinDataAllocator safe_allocator(allocator);
+ std::unique_ptr<TfLiteArgMinParams,
+ SafeBuiltinDataAllocator::BuiltinDataDeleter>
+ params = safe_allocator.Allocate<TfLiteArgMinParams>();
+ TF_LITE_ENSURE(error_reporter, params != nullptr);
+
+ const ArgMinOptions* schema_params = op->builtin_options_as_ArgMinOptions();
+
+ if (schema_params != nullptr) {
+ TF_LITE_ENSURE_STATUS(ConvertTensorType(
+        schema_params->output_type(), &params->output_type, error_reporter));
+ } else {
+ // TODO(b/157480169): We should either return kTfLiteError or fill in some
+ // reasonable defaults in the params struct. We are not doing so until we
+    // better understand the ramifications of changing the legacy behavior.
+ }
+
+ *builtin_data = params.release();
+ return kTfLiteOk;
+}
+
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
@@ -430,6 +515,22 @@
SafeBuiltinDataAllocator safe_allocator(allocator);
*builtin_data = nullptr;
switch (op_type) {
+ case BuiltinOperator_ABS: {
+ return ParseAbs(op, op_type, error_reporter, allocator, builtin_data);
+ }
+
+ case BuiltinOperator_ADD: {
+ return ParseAdd(op, op_type, error_reporter, allocator, builtin_data);
+ }
+
+ case BuiltinOperator_ARG_MAX: {
+ return ParseArgMax(op, op_type, error_reporter, allocator, builtin_data);
+ }
+
+ case BuiltinOperator_ARG_MIN: {
+ return ParseArgMin(op, op_type, error_reporter, allocator, builtin_data);
+ }
+
case BuiltinOperator_CONV_2D: {
return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data);
}
@@ -586,16 +687,6 @@
*builtin_data = params.release();
return kTfLiteOk;
}
- case BuiltinOperator_ADD: {
- auto params = safe_allocator.Allocate<TfLiteAddParams>();
- TF_LITE_ENSURE(error_reporter, params != nullptr);
- if (const auto* schema_params = op->builtin_options_as_AddOptions()) {
- params->activation =
- ConvertActivation(schema_params->fused_activation_function());
- }
- *builtin_data = params.release();
- return kTfLiteOk;
- }
case BuiltinOperator_DIV: {
auto params = safe_allocator.Allocate<TfLiteDivParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -838,28 +929,6 @@
*builtin_data = params.release();
return kTfLiteOk;
}
- case BuiltinOperator_ARG_MAX: {
- auto params = safe_allocator.Allocate<TfLiteArgMaxParams>();
- TF_LITE_ENSURE(error_reporter, params != nullptr);
- if (const auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
- TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->output_type(),
-                                                &params->output_type,
- error_reporter));
- }
- *builtin_data = params.release();
- return kTfLiteOk;
- }
- case BuiltinOperator_ARG_MIN: {
- auto params = safe_allocator.Allocate<TfLiteArgMinParams>();
- TF_LITE_ENSURE(error_reporter, params != nullptr);
- if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
- TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->output_type(),
-                                                &params->output_type,
- error_reporter));
- }
- *builtin_data = params.release();
- return kTfLiteOk;
- }
case BuiltinOperator_TRANSPOSE_CONV: {
auto params = safe_allocator.Allocate<TfLiteTransposeConvParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@@ -1019,7 +1088,6 @@
return kTfLiteOk;
}
// Below are the ops with no builtin_data structure.
- case BuiltinOperator_ABS:
case BuiltinOperator_BATCH_TO_SPACE_ND:
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
// ok for now, since there is no call implementation either.
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h
index 78d2aca..a6431aa 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.h
@@ -75,6 +75,22 @@
// removed once we are no longer using ParseOpData for the OpResolver
// implementation in micro.
+TfLiteStatus ParseAbs(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseAdd(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseArgMax(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
+
+TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
+
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index 1b76f44..8c99f77 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -108,31 +108,23 @@
// MicroMutableOpResolver object.
TfLiteStatus AddAbs() {
- // TODO(b/149408647): Replace ParseOpData with the operator specific parse
- // function.
return AddBuiltin(BuiltinOperator_ABS, *tflite::ops::micro::Register_ABS(),
- ParseOpData);
+ ParseAbs);
}
TfLiteStatus AddAdd() {
- // TODO(b/149408647): Replace ParseOpData with the operator specific parse
- // function.
return AddBuiltin(BuiltinOperator_ADD, *tflite::ops::micro::Register_ADD(),
- ParseOpData);
+ ParseAdd);
}
TfLiteStatus AddArgMax() {
- // TODO(b/149408647): Replace ParseOpData with the operator specific parse
- // function.
return AddBuiltin(BuiltinOperator_ARG_MAX,
- *tflite::ops::micro::Register_ARG_MAX(), ParseOpData);
+ *tflite::ops::micro::Register_ARG_MAX(), ParseArgMax);
}
TfLiteStatus AddArgMin() {
- // TODO(b/149408647): Replace ParseOpData with the operator specific parse
- // function.
return AddBuiltin(BuiltinOperator_ARG_MIN,
- *tflite::ops::micro::Register_ARG_MIN(), ParseOpData);
+ *tflite::ops::micro::Register_ARG_MIN(), ParseArgMin);
}
TfLiteStatus AddAveragePool2D() {
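
With these changes applied, each Add<op>() method above wires the op's
kernel registration to its dedicated parse function. A hypothetical usage
sketch follows; the resolver's exact declaration (e.g. whether it takes a
capacity template argument) varies across revisions, so treat it as
illustrative:

    #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

    // Illustrative only: registers just the four ops touched by this
    // change; return statuses are ignored here for brevity.
    void RegisterSelectedOps(tflite::MicroMutableOpResolver* resolver) {
      resolver->AddAbs();     // builtin data parsed by ParseAbs
      resolver->AddAdd();     // builtin data parsed by ParseAdd
      resolver->AddArgMax();  // builtin data parsed by ParseArgMax
      resolver->AddArgMin();  // builtin data parsed by ParseArgMin
    }

Because each registration now references only its own parse function, a
linker can drop the parsing code for ops an application never registers,
which is the selective-registration motivation noted in the comment above
ParseAbs.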