Refactor SVDF prior to CEVA-optimized kernel integration.

Move Prepare and the float eval path to svdf_common.cc so they can be
shared with optimized kernel implementations.
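With PrepareSVDF() declared in svdf.h and defined in svdf_common.cc, an
optimized port only needs to supply its own Init/Eval and registration. A
minimal sketch of what such a port-specific kernel could look like is below;
the ceva/ path and the stub Eval body are placeholders, while
Register_SVDF(), PrepareSVDF(), OpData, and the Init pattern mirror this
change.

// Sketch of a hypothetical port-specific kernel, e.g.
// tensorflow/lite/micro/kernels/ceva/svdf.cc (the path is an assumption).
// In TFLite Micro the build substitutes such a file for the reference
// kernels/svdf.cc; it reuses the shared OpData and PrepareSVDF() and only
// overrides the evaluation path.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"  // TFLITE_DCHECK
#include "tensorflow/lite/micro/kernels/svdf.h"

namespace tflite {
namespace {

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // Placeholder: the CEVA-optimized evaluation would go here, falling back
  // to the shared reference kernels in svdf_common.cc where no optimized
  // path exists.
  return kTfLiteError;
}

}  // namespace

TfLiteRegistration Register_SVDF() {
  return {/*init=*/Init,
          /*free=*/nullptr,
          /*prepare=*/PrepareSVDF,  // shared shape/type checks and quantization setup
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite

Keeping the validation and quantization-multiplier setup in one shared
PrepareSVDF() means the reference and optimized kernels cannot drift apart
in their Prepare logic.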
diff --git a/tensorflow/lite/micro/kernels/svdf.cc b/tensorflow/lite/micro/kernels/svdf.cc
index 9d7c09e..096b0e9 100644
--- a/tensorflow/lite/micro/kernels/svdf.cc
+++ b/tensorflow/lite/micro/kernels/svdf.cc
@@ -31,17 +31,6 @@
namespace tflite {
namespace {
-// Input tensors.
-constexpr int kInputTensor = 0;
-constexpr int kWeightsFeatureTensor = 1;
-constexpr int kWeightsTimeTensor = 2;
-constexpr int kBiasTensor = 3;
-// This is a variable tensor, and will be modified by this op.
-constexpr int kInputActivationStateTensor = 4;
-
-// Output tensor.
-constexpr int kOutputTensor = 0;
-
/**
* This version of SVDF is specific to TFLite Micro. It contains the following
* differences between the TFLite version:
@@ -53,142 +42,9 @@
* resizing.
*/
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TFLITE_DCHECK(node->builtin_data != nullptr);
-
- const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
-
- // Validate Tensor Inputs (dtype depends on quantization):
- // [0] = Input, {2, batch_size, input_size}
- // [1] = Weights Feature, {2, num_filters, input_size}
- // [2] = Weights Time, {2, num_filters, memory_size}
- // [3] = Bias (optional), {1, num_units}
- // [4] = Activation State (variable),
- // {2, batch_size, memory_size * num_filters}
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE(context, input != nullptr);
- const TfLiteTensor* weights_feature =
- GetInput(context, node, kWeightsFeatureTensor);
- TF_LITE_ENSURE(context, weights_feature != nullptr);
- const TfLiteTensor* weights_time =
- GetInput(context, node, kWeightsTimeTensor);
- TF_LITE_ENSURE(context, weights_time != nullptr);
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
- const TfLiteTensor* activation_state =
- GetInput(context, node, kInputActivationStateTensor);
- TF_LITE_ENSURE(context, activation_state != nullptr);
-
- // Define input constants based on input tensor definition above:
- const int rank = params->rank;
- const int input_size = input->dims->data[1];
- const int batch_size = input->dims->data[0];
- const int num_filters = weights_feature->dims->data[0];
- TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
- const int num_units = num_filters / rank;
- const int memory_size = weights_time->dims->data[1];
-
- // Validate Input Tensor:
- TF_LITE_ENSURE(context,
- input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
- TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
-
- // Validate Tensor Output:
- // [0] = float/int8_t, {2, batch_size, num_units}
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE(context, output != nullptr);
- TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
- TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
- TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
-
- // Validate Weights Feature Input Tensor:
- TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
- TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
-
- // Validate Weights Time Input Tensor:
- TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
- TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
- TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
-
- // Validate Optional Bias Input Tensor:
- if (bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
- }
-
- // Validate Activation State Input Tensor:
- TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
- TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
- TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
- memory_size * num_filters);
- // Since is_variable is not part of TFLiteEvalTensor, check is_variable here.
- TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);
-
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
-
- TFLITE_DCHECK(node->user_data != nullptr);
- OpData* data = static_cast<OpData*>(node->user_data);
-
- if (input->type == kTfLiteInt8) {
- TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
- TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
- TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
- if (bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
- }
-
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
-
- const double effective_scale_1 = static_cast<double>(
- input->params.scale * weights_feature->params.scale /
- activation_state->params.scale);
- const double effective_scale_2 =
- static_cast<double>(activation_state->params.scale *
- weights_time->params.scale / output->params.scale);
-
- // TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
- TF_LITE_ENSURE(
- context,
- std::abs(static_cast<double>(bias->params.scale) -
- static_cast<double>(activation_state->params.scale *
- weights_time->params.scale)) < 1e-5);
-
- QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
- &(data->effective_scale_1_b));
- QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
- &(data->effective_scale_2_b));
-
- data->input_zero_point = input->params.zero_point;
- data->output_zero_point = output->params.zero_point;
-
- TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
-
- const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
- context, batch_size * num_filters * sizeof(int32_t),
- &(data->scratch_tensor_index));
- TF_LITE_ENSURE_OK(context, scratch_status);
-
- const TfLiteStatus scratch_output_status =
- context->RequestScratchBufferInArena(
- context, batch_size * num_units * sizeof(int32_t),
- &(data->scratch_output_tensor_index));
- TF_LITE_ENSURE_OK(context, scratch_output_status);
- } else {
- TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
- TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
- TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
- if (bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
- }
- TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
-
- TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
- const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
- context, batch_size * num_filters * sizeof(float),
- &(data->scratch_tensor_index));
- TF_LITE_ENSURE_OK(context, scratch_status);
- }
-
- return kTfLiteOk;
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+ return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -241,7 +97,7 @@
TfLiteRegistration Register_SVDF() {
return {/*init=*/Init,
/*free=*/nullptr,
- /*prepare=*/Prepare,
+ /*prepare=*/PrepareSVDF,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
diff --git a/tensorflow/lite/micro/kernels/svdf.h b/tensorflow/lite/micro/kernels/svdf.h
index af07ace..f802f63 100644
--- a/tensorflow/lite/micro/kernels/svdf.h
+++ b/tensorflow/lite/micro/kernels/svdf.h
@@ -35,6 +35,17 @@
int output_zero_point;
};
+// Input tensors.
+constexpr int kInputTensor = 0;
+constexpr int kWeightsFeatureTensor = 1;
+constexpr int kWeightsTimeTensor = 2;
+constexpr int kBiasTensor = 3;
+// This is a variable tensor, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
// TensorflowLite Micro-specific reference implementation for Integer SVDF.
void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* input_tensor,
@@ -54,6 +65,8 @@
int scratch_tensor_index, TfLiteEvalTensor* activation_state,
TfLiteEvalTensor* output);
+TfLiteStatus PrepareSVDF(TfLiteContext* context, TfLiteNode* node);
+
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
diff --git a/tensorflow/lite/micro/kernels/svdf_common.cc b/tensorflow/lite/micro/kernels/svdf_common.cc
index e876b21..2d64104 100644
--- a/tensorflow/lite/micro/kernels/svdf_common.cc
+++ b/tensorflow/lite/micro/kernels/svdf_common.cc
@@ -309,4 +309,144 @@
bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr);
}
+TfLiteStatus PrepareSVDF(TfLiteContext* context, TfLiteNode* node) {
+ TFLITE_DCHECK(node->builtin_data != nullptr);
+
+ const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
+
+ // Validate Tensor Inputs (dtype depends on quantization):
+ // [0] = Input, {2, batch_size, input_size}
+ // [1] = Weights Feature, {2, num_filters, input_size}
+ // [2] = Weights Time, {2, num_filters, memory_size}
+ // [3] = Bias (optional), {1, num_units}
+ // [4] = Activation State (variable),
+ // {2, batch_size, memory_size * num_filters}
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, input != nullptr);
+ const TfLiteTensor* weights_feature =
+ GetInput(context, node, kWeightsFeatureTensor);
+ TF_LITE_ENSURE(context, weights_feature != nullptr);
+ const TfLiteTensor* weights_time =
+ GetInput(context, node, kWeightsTimeTensor);
+ TF_LITE_ENSURE(context, weights_time != nullptr);
+ const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ const TfLiteTensor* activation_state =
+ GetInput(context, node, kInputActivationStateTensor);
+ TF_LITE_ENSURE(context, activation_state != nullptr);
+
+ // Define input constants based on input tensor definition above:
+ const int rank = params->rank;
+ const int input_size = input->dims->data[1];
+ const int batch_size = input->dims->data[0];
+ const int num_filters = weights_feature->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
+ const int num_units = num_filters / rank;
+ const int memory_size = weights_time->dims->data[1];
+
+ // Validate Input Tensor:
+ TF_LITE_ENSURE(context,
+ input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+
+ // Validate Tensor Output:
+ // [0] = float/int8_t, {2, batch_size, num_units}
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE(context, output != nullptr);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
+ TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
+
+ // Validate Weights Feature Input Tensor:
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
+ TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
+
+ // Validate Weights Time Input Tensor:
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
+ TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
+ TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
+
+ // Validate Optional Bias Input Tensor:
+ if (bias != nullptr) {
+ TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
+ }
+
+ // Validate Activation State Input Tensor:
+ TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+ TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
+ memory_size * num_filters);
+ // Since is_variable is not part of TfLiteEvalTensor, check is_variable here.
+ TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);
+
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+
+ TFLITE_DCHECK(node->user_data != nullptr);
+ OpData* data = static_cast<OpData*>(node->user_data);
+
+ if (input->type == kTfLiteInt8) {
+ TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
+ TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
+ TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
+ if (bias != nullptr) {
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+ }
+
+ TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
+
+ const double effective_scale_1 = static_cast<double>(
+ input->params.scale * weights_feature->params.scale /
+ activation_state->params.scale);
+ const double effective_scale_2 =
+ static_cast<double>(activation_state->params.scale *
+ weights_time->params.scale / output->params.scale);
+
+ // Bias is optional, so only check its scale when it is present.
+ // TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
+ if (bias != nullptr) {
+ TF_LITE_ENSURE(
+ context,
+ std::abs(static_cast<double>(bias->params.scale) -
+ static_cast<double>(activation_state->params.scale *
+ weights_time->params.scale)) < 1e-5);
+ }
+
+ QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
+ &(data->effective_scale_1_b));
+ QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
+ &(data->effective_scale_2_b));
+
+ data->input_zero_point = input->params.zero_point;
+ data->output_zero_point = output->params.zero_point;
+
+ TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+
+ const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
+ context, batch_size * num_filters * sizeof(int32_t),
+ &(data->scratch_tensor_index));
+ TF_LITE_ENSURE_OK(context, scratch_status);
+
+ const TfLiteStatus scratch_output_status =
+ context->RequestScratchBufferInArena(
+ context, batch_size * num_units * sizeof(int32_t),
+ &(data->scratch_output_tensor_index));
+ TF_LITE_ENSURE_OK(context, scratch_output_status);
+ } else {
+ TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
+ if (bias != nullptr) {
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+ }
+ TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
+
+ TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+ const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
+ context, batch_size * num_filters * sizeof(float),
+ &(data->scratch_tensor_index));
+ TF_LITE_ENSURE_OK(context, scratch_status);
+ }
+
+ return kTfLiteOk;
+}
+
} // namespace tflite