Cleanup reshape_test to use test_helpers.h instead of testing/test_utils.h

PiperOrigin-RevId: 335941197
Change-Id: I50488b7ccb74383545de0dd7220841fa1f4e579d
diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc
index 91ecbdc..48d1956 100644
--- a/tensorflow/lite/micro/kernels/reshape_test.cc
+++ b/tensorflow/lite/micro/kernels/reshape_test.cc
@@ -23,7 +23,6 @@
 #include "tensorflow/lite/micro/micro_utils.h"
 #include "tensorflow/lite/micro/test_helpers.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
-#include "tensorflow/lite/micro/testing/test_utils.h"
 
 namespace tflite {
 namespace testing {
@@ -113,22 +112,41 @@
                          expected_dims_len, expect_failure);
 }
 
-template <typename T = float, TfLiteType tensor_type = kTfLiteFloat32>
-void TestReshape(const int* input_dims_data, const T* input_data,
+void TestReshape(const int* input_dims_data, const float* input_data,
                  const int* shape_dims_data, const int32_t* shape_data,
-                 int* output_dims_data, T* output_data,
-                 const T* expected_output, const size_t expected_output_len,
+                 int* output_dims_data, float* output_data,
+                 const float* expected_output, const size_t expected_output_len,
                  const int* expected_dims, const size_t expected_dims_len,
                  bool expect_failure = false) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
-  TfLiteTensor input_tensor =
-      CreateTensor<T, tensor_type>(input_data, input_dims);
-  TfLiteTensor shape_tensor =
-      CreateTensor<int32_t, kTfLiteInt32>(shape_data, shape_dims);
-  TfLiteTensor output_tensor =
-      CreateTensor<T, tensor_type>(output_data, output_dims);
+  TfLiteTensor input_tensor = CreateFloatTensor(input_data, input_dims);
+  TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims);
+  TfLiteTensor output_tensor = CreateFloatTensor(output_data, output_dims);
+
+  TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor,
+                       expected_output, expected_output_len, expected_dims,
+                       expected_dims_len, expect_failure);
+}
+
+template <typename T>
+void TestReshapeQuantized(const int* input_dims_data, const T* input_data,
+                          const int* shape_dims_data, const int32_t* shape_data,
+                          int* output_dims_data, T* output_data,
+                          const T* expected_output,
+                          const size_t expected_output_len,
+                          const int* expected_dims,
+                          const size_t expected_dims_len,
+                          bool expect_failure = false) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  TfLiteTensor input_tensor = CreateQuantizedTensor(
+      input_data, input_dims, /*scale=*/1.f, /*zero_point=*/0);
+  TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims);
+  TfLiteTensor output_tensor = CreateQuantizedTensor(
+      output_data, output_dims, /*scale=*/1.f, /*zero_point=*/0);
 
   TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor,
                        expected_output, expected_output_len, expected_dims,
@@ -233,11 +251,11 @@
                                output_dims, output_data_float,
                                golden_output_float, golden_output_len,
                                golden_dims, golden_dims_len, false);
-  tflite::testing::TestReshape<int8_t, kTfLiteInt8>(
+  tflite::testing::TestReshapeQuantized(
       input_dims, input_int8, shape_dims, shape_int32, output_dims,
       output_data_int8, golden_output_int8, golden_output_len, golden_dims,
       golden_dims_len, false);
-  tflite::testing::TestReshape<uint8_t, kTfLiteUInt8>(
+  tflite::testing::TestReshapeQuantized(
       input_dims, input_uint8, shape_dims, shape_int32, output_dims,
       output_data_uint8, golden_output_uint8, golden_output_len, golden_dims,
       golden_dims_len, false);
@@ -265,11 +283,11 @@
                                output_dims, output_data_float,
                                golden_output_float, golden_output_len,
                                golden_dims, golden_dims_len, false);
-  tflite::testing::TestReshape<int8_t, kTfLiteInt8>(
+  tflite::testing::TestReshapeQuantized(
       input_dims, input_int8, shape_dims, shape_int32, output_dims,
       output_data_int8, golden_output_int8, golden_output_len, golden_dims,
       golden_dims_len, false);
-  tflite::testing::TestReshape<uint8_t, kTfLiteUInt8>(
+  tflite::testing::TestReshapeQuantized(
       input_dims, input_uint8, shape_dims, shape_int32, output_dims,
       output_data_uint8, golden_output_uint8, golden_output_len, golden_dims,
       golden_dims_len, false);
@@ -297,11 +315,11 @@
                                output_dims, output_data_float,
                                golden_output_float, golden_output_len,
                                golden_dims, golden_dims_len, false);
-  tflite::testing::TestReshape<int8_t, kTfLiteInt8>(
+  tflite::testing::TestReshapeQuantized(
       input_dims, input_int8, shape_dims, shape_int32, output_dims,
       output_data_int8, golden_output_int8, golden_output_len, golden_dims,
       golden_dims_len, false);
-  tflite::testing::TestReshape<uint8_t, kTfLiteUInt8>(
+  tflite::testing::TestReshapeQuantized(
       input_dims, input_uint8, shape_dims, shape_int32, output_dims,
       output_data_uint8, golden_output_uint8, golden_output_len, golden_dims,
       golden_dims_len, false);
@@ -327,8 +345,8 @@
   TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
 
   const int32_t shape_data[] = {0};
-  auto shape_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
-      shape_data, shape_dims);
+  auto shape_tensor =
+      tflite::testing::CreateInt32Tensor(shape_data, shape_dims);
   const float expected_output_with_shape[] = {};
   const int expected_output_with_shape_len = 0;
   const float expected_output_no_shape[] = {3};