Refactoring SVDF prior to ceva optimized kernel integration.
Moved prepare and float eval to svdf_common.cc
diff --git a/tensorflow/lite/micro/kernels/svdf.cc b/tensorflow/lite/micro/kernels/svdf.cc
index 9d7c09e..096b0e9 100644
--- a/tensorflow/lite/micro/kernels/svdf.cc
+++ b/tensorflow/lite/micro/kernels/svdf.cc
@@ -31,17 +31,6 @@
 namespace tflite {
 namespace {
 
-// Input tensors.
-constexpr int kInputTensor = 0;
-constexpr int kWeightsFeatureTensor = 1;
-constexpr int kWeightsTimeTensor = 2;
-constexpr int kBiasTensor = 3;
-// This is a variable tensor, and will be modified by this op.
-constexpr int kInputActivationStateTensor = 4;
-
-// Output tensor.
-constexpr int kOutputTensor = 0;
-
 /**
  * This version of SVDF is specific to TFLite Micro. It contains the following
  * differences between the TFLite version:
@@ -53,142 +42,9 @@
  * resizing.
  */
 
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-
-  const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
-
-  // Validate Tensor Inputs (dtype depends on quantization):
-  // [0] = Input, {2, batch_size, input_size}
-  // [1] = Weights Feature, {2, num_filters, input_size}
-  // [2] = Weights Time, {2, num_filters, memory_size}
-  // [3] = Bias (optional), {1, num_units}
-  // [4] = Activation State (variable),
-  //         {2, batch_size, memory_size * num_filters}
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
-  const TfLiteTensor* weights_feature =
-      GetInput(context, node, kWeightsFeatureTensor);
-  TF_LITE_ENSURE(context, weights_feature != nullptr);
-  const TfLiteTensor* weights_time =
-      GetInput(context, node, kWeightsTimeTensor);
-  TF_LITE_ENSURE(context, weights_time != nullptr);
-  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
-  const TfLiteTensor* activation_state =
-      GetInput(context, node, kInputActivationStateTensor);
-  TF_LITE_ENSURE(context, activation_state != nullptr);
-
-  // Define input constants based on input tensor definition above:
-  const int rank = params->rank;
-  const int input_size = input->dims->data[1];
-  const int batch_size = input->dims->data[0];
-  const int num_filters = weights_feature->dims->data[0];
-  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
-  const int num_units = num_filters / rank;
-  const int memory_size = weights_time->dims->data[1];
-
-  // Validate Input Tensor:
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
-
-  // Validate Tensor Output:
-  // [0] = float/int8_t, {2, batch_size, num_units}
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
-  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
-  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
-
-  // Validate Weights Feature Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
-  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
-
-  // Validate Weights Time Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
-  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
-  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
-
-  // Validate Optional Bias Input Tensor:
-  if (bias != nullptr) {
-    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
-  }
-
-  // Validate Activation State Input Tensor:
-  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
-  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
-  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
-                    memory_size * num_filters);
-  // Since is_variable is not part of TFLiteEvalTensor, check is_variable here.
-  TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);
-
-  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
-
-  TFLITE_DCHECK(node->user_data != nullptr);
-  OpData* data = static_cast<OpData*>(node->user_data);
-
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
-    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
-    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
-    if (bias != nullptr) {
-      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
-    }
-
-    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
-
-    const double effective_scale_1 = static_cast<double>(
-        input->params.scale * weights_feature->params.scale /
-        activation_state->params.scale);
-    const double effective_scale_2 =
-        static_cast<double>(activation_state->params.scale *
-                            weights_time->params.scale / output->params.scale);
-
-    // TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
-    TF_LITE_ENSURE(
-        context,
-        std::abs(static_cast<double>(bias->params.scale) -
-                 static_cast<double>(activation_state->params.scale *
-                                     weights_time->params.scale)) < 1e-5);
-
-    QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
-                       &(data->effective_scale_1_b));
-    QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
-                       &(data->effective_scale_2_b));
-
-    data->input_zero_point = input->params.zero_point;
-    data->output_zero_point = output->params.zero_point;
-
-    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
-
-    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
-        context, batch_size * num_filters * sizeof(int32_t),
-        &(data->scratch_tensor_index));
-    TF_LITE_ENSURE_OK(context, scratch_status);
-
-    const TfLiteStatus scratch_output_status =
-        context->RequestScratchBufferInArena(
-            context, batch_size * num_units * sizeof(int32_t),
-            &(data->scratch_output_tensor_index));
-    TF_LITE_ENSURE_OK(context, scratch_output_status);
-  } else {
-    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
-    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
-    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
-    if (bias != nullptr) {
-      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
-    }
-    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
-
-    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
-    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
-        context, batch_size * num_filters * sizeof(float),
-        &(data->scratch_tensor_index));
-    TF_LITE_ENSURE_OK(context, scratch_status);
-  }
-
-  return kTfLiteOk;
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -241,7 +97,7 @@
 TfLiteRegistration Register_SVDF() {
   return {/*init=*/Init,
           /*free=*/nullptr,
-          /*prepare=*/Prepare,
+          /*prepare=*/PrepareSVDF,
           /*invoke=*/Eval,
           /*profiling_string=*/nullptr,
           /*builtin_code=*/0,
diff --git a/tensorflow/lite/micro/kernels/svdf.h b/tensorflow/lite/micro/kernels/svdf.h
index af07ace..f802f63 100644
--- a/tensorflow/lite/micro/kernels/svdf.h
+++ b/tensorflow/lite/micro/kernels/svdf.h
@@ -35,6 +35,17 @@
   int output_zero_point;
 };
 
+// Input tensors.
+constexpr int kInputTensor = 0;
+constexpr int kWeightsFeatureTensor = 1;
+constexpr int kWeightsTimeTensor = 2;
+constexpr int kBiasTensor = 3;
+// This is a variable tensor, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
 // TensorflowLite Micro-specific reference implementation for Integer SVDF.
 void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
                               const TfLiteEvalTensor* input_tensor,
@@ -54,6 +65,8 @@
                    int scratch_tensor_index, TfLiteEvalTensor* activation_state,
                    TfLiteEvalTensor* output);
 
+TfLiteStatus PrepareSVDF(TfLiteContext* context, TfLiteNode* node);
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
diff --git a/tensorflow/lite/micro/kernels/svdf_common.cc b/tensorflow/lite/micro/kernels/svdf_common.cc
index e876b21..2d64104 100644
--- a/tensorflow/lite/micro/kernels/svdf_common.cc
+++ b/tensorflow/lite/micro/kernels/svdf_common.cc
@@ -309,4 +309,142 @@
       bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr);
 }
 
+TfLiteStatus PrepareSVDF(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+
+  const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);
+
+  // Validate Tensor Inputs (dtype depends on quantization):
+  // [0] = Input, {2, batch_size, input_size}
+  // [1] = Weights Feature, {2, num_filters, input_size}
+  // [2] = Weights Time, {2, num_filters, memory_size}
+  // [3] = Bias (optional), {1, num_units}
+  // [4] = Activation State (variable),
+  //         {2, batch_size, memory_size * num_filters}
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  const TfLiteTensor* weights_feature =
+      GetInput(context, node, kWeightsFeatureTensor);
+  TF_LITE_ENSURE(context, weights_feature != nullptr);
+  const TfLiteTensor* weights_time =
+      GetInput(context, node, kWeightsTimeTensor);
+  TF_LITE_ENSURE(context, weights_time != nullptr);
+  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+  const TfLiteTensor* activation_state =
+      GetInput(context, node, kInputActivationStateTensor);
+  TF_LITE_ENSURE(context, activation_state != nullptr);
+
+  // Define input constants based on input tensor definition above:
+  const int rank = params->rank;
+  const int input_size = input->dims->data[1];
+  const int batch_size = input->dims->data[0];
+  const int num_filters = weights_feature->dims->data[0];
+  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
+  const int num_units = num_filters / rank;
+  const int memory_size = weights_time->dims->data[1];
+
+  // Validate Input Tensor:
+  TF_LITE_ENSURE(context,
+                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+
+  // Validate Tensor Output:
+  // [0] = float/int8_t, {2, batch_size, num_units}
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
+  TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
+  TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);
+
+  // Validate Weights Feature Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
+  TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);
+
+  // Validate Weights Time Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
+  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
+  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);
+
+  // Validate Optional Bias Input Tensor:
+  if (bias != nullptr) {
+    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
+  }
+
+  // Validate Activation State Input Tensor:
+  TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
+  TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
+                    memory_size * num_filters);
+  // Since is_variable is not part of TFLiteEvalTensor, check is_variable here.
+  TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);
+
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  OpData* data = static_cast<OpData*>(node->user_data);
+
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
+    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16);
+    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16);
+    if (bias != nullptr) {
+      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
+    }
+
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
+
+    const double effective_scale_1 = static_cast<double>(
+        input->params.scale * weights_feature->params.scale /
+        activation_state->params.scale);
+    const double effective_scale_2 =
+        static_cast<double>(activation_state->params.scale *
+                            weights_time->params.scale / output->params.scale);
+
+    // TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
+    TF_LITE_ENSURE(
+        context,
+        std::abs(static_cast<double>(bias->params.scale) -
+                 static_cast<double>(activation_state->params.scale *
+                                     weights_time->params.scale)) < 1e-5);
+
+    QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
+                       &(data->effective_scale_1_b));
+    QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
+                       &(data->effective_scale_2_b));
+
+    data->input_zero_point = input->params.zero_point;
+    data->output_zero_point = output->params.zero_point;
+
+    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+
+    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
+        context, batch_size * num_filters * sizeof(int32_t),
+        &(data->scratch_tensor_index));
+    TF_LITE_ENSURE_OK(context, scratch_status);
+
+    const TfLiteStatus scratch_output_status =
+        context->RequestScratchBufferInArena(
+            context, batch_size * num_units * sizeof(int32_t),
+            &(data->scratch_output_tensor_index));
+    TF_LITE_ENSURE_OK(context, scratch_output_status);
+  } else {
+    TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
+    TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
+    if (bias != nullptr) {
+      TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+    }
+    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
+
+    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+    const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
+        context, batch_size * num_filters * sizeof(float),
+        &(data->scratch_tensor_index));
+    TF_LITE_ENSURE_OK(context, scratch_status);
+  }
+
+  return kTfLiteOk;
+}
+
 }  // namespace tflite