Cleanup reshape_test to use test_helpers.h instead of testing/test_utils.h
PiperOrigin-RevId: 335941197
Change-Id: I50488b7ccb74383545de0dd7220841fa1f4e579d
diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc
index 91ecbdc..48d1956 100644
--- a/tensorflow/lite/micro/kernels/reshape_test.cc
+++ b/tensorflow/lite/micro/kernels/reshape_test.cc
@@ -23,7 +23,6 @@
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
-#include "tensorflow/lite/micro/testing/test_utils.h"
namespace tflite {
namespace testing {
@@ -113,22 +112,41 @@
expected_dims_len, expect_failure);
}
-template <typename T = float, TfLiteType tensor_type = kTfLiteFloat32>
-void TestReshape(const int* input_dims_data, const T* input_data,
+void TestReshape(const int* input_dims_data, const float* input_data,
const int* shape_dims_data, const int32_t* shape_data,
- int* output_dims_data, T* output_data,
- const T* expected_output, const size_t expected_output_len,
+ int* output_dims_data, float* output_data,
+ const float* expected_output, const size_t expected_output_len,
const int* expected_dims, const size_t expected_dims_len,
bool expect_failure = false) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
- TfLiteTensor input_tensor =
- CreateTensor<T, tensor_type>(input_data, input_dims);
- TfLiteTensor shape_tensor =
- CreateTensor<int32_t, kTfLiteInt32>(shape_data, shape_dims);
- TfLiteTensor output_tensor =
- CreateTensor<T, tensor_type>(output_data, output_dims);
+ TfLiteTensor input_tensor = CreateFloatTensor(input_data, input_dims);
+ TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims);
+ TfLiteTensor output_tensor = CreateFloatTensor(output_data, output_dims);
+
+ TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor,
+ expected_output, expected_output_len, expected_dims,
+ expected_dims_len, expect_failure);
+}
+
+template <typename T>
+void TestReshapeQuantized(const int* input_dims_data, const T* input_data,
+ const int* shape_dims_data, const int32_t* shape_data,
+ int* output_dims_data, T* output_data,
+ const T* expected_output,
+ const size_t expected_output_len,
+ const int* expected_dims,
+ const size_t expected_dims_len,
+ bool expect_failure = false) {
+ TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+ TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+ TfLiteTensor input_tensor = CreateQuantizedTensor(
+ input_data, input_dims, /*scale=*/1.f, /*zero_point=*/0);
+ TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims);
+ TfLiteTensor output_tensor = CreateQuantizedTensor(
+ output_data, output_dims, /*scale=*/1.f, /*zero_point=*/0);
TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor,
expected_output, expected_output_len, expected_dims,
@@ -233,11 +251,11 @@
output_dims, output_data_float,
golden_output_float, golden_output_len,
golden_dims, golden_dims_len, false);
- tflite::testing::TestReshape<int8_t, kTfLiteInt8>(
+ tflite::testing::TestReshapeQuantized(
input_dims, input_int8, shape_dims, shape_int32, output_dims,
output_data_int8, golden_output_int8, golden_output_len, golden_dims,
golden_dims_len, false);
- tflite::testing::TestReshape<uint8_t, kTfLiteUInt8>(
+ tflite::testing::TestReshapeQuantized(
input_dims, input_uint8, shape_dims, shape_int32, output_dims,
output_data_uint8, golden_output_uint8, golden_output_len, golden_dims,
golden_dims_len, false);
@@ -265,11 +283,11 @@
output_dims, output_data_float,
golden_output_float, golden_output_len,
golden_dims, golden_dims_len, false);
- tflite::testing::TestReshape<int8_t, kTfLiteInt8>(
+ tflite::testing::TestReshapeQuantized(
input_dims, input_int8, shape_dims, shape_int32, output_dims,
output_data_int8, golden_output_int8, golden_output_len, golden_dims,
golden_dims_len, false);
- tflite::testing::TestReshape<uint8_t, kTfLiteUInt8>(
+ tflite::testing::TestReshapeQuantized(
input_dims, input_uint8, shape_dims, shape_int32, output_dims,
output_data_uint8, golden_output_uint8, golden_output_len, golden_dims,
golden_dims_len, false);
@@ -297,11 +315,11 @@
output_dims, output_data_float,
golden_output_float, golden_output_len,
golden_dims, golden_dims_len, false);
- tflite::testing::TestReshape<int8_t, kTfLiteInt8>(
+ tflite::testing::TestReshapeQuantized(
input_dims, input_int8, shape_dims, shape_int32, output_dims,
output_data_int8, golden_output_int8, golden_output_len, golden_dims,
golden_dims_len, false);
- tflite::testing::TestReshape<uint8_t, kTfLiteUInt8>(
+ tflite::testing::TestReshapeQuantized(
input_dims, input_uint8, shape_dims, shape_int32, output_dims,
output_data_uint8, golden_output_uint8, golden_output_len, golden_dims,
golden_dims_len, false);
@@ -327,8 +345,8 @@
TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
const int32_t shape_data[] = {0};
- auto shape_tensor = tflite::testing::CreateTensor<int32_t, kTfLiteInt32>(
- shape_data, shape_dims);
+ auto shape_tensor =
+ tflite::testing::CreateInt32Tensor(shape_data, shape_dims);
const float expected_output_with_shape[] = {};
const int expected_output_with_shape_len = 0;
const float expected_output_no_shape[] = {3};