Small cleanup: Use GetInput/GetOutput helper functions.
PiperOrigin-RevId: 302694846
Change-Id: I6838ab7aeeb2550805554693fc3412b1edaa4688
diff --git a/tensorflow/lite/micro/kernels/dequantize.cc b/tensorflow/lite/micro/kernels/dequantize.cc
index 1583a6f..21c34a4 100644
--- a/tensorflow/lite/micro/kernels/dequantize.cc
+++ b/tensorflow/lite/micro/kernels/dequantize.cc
@@ -33,8 +33,8 @@
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
// TODO(b/140515557): Add cached dequant to improve hybrid model performance.
- TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
input->type == kTfLiteInt8 ||
@@ -46,8 +46,8 @@
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
tflite::DequantizationParams op_params;
op_params.zero_point = input->params.zero_point;
diff --git a/tensorflow/lite/micro/kernels/pack.cc b/tensorflow/lite/micro/kernels/pack.cc
index e9b7e65..3c47ce8 100644
--- a/tensorflow/lite/micro/kernels/pack.cc
+++ b/tensorflow/lite/micro/kernels/pack.cc
@@ -34,7 +34,7 @@
TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output, int values_count, int axis) {
const int dimensions = output->dims->size;
- const TfLiteTensor* input0 = &context->tensors[node->inputs->data[0]];
+ const TfLiteTensor* input0 = GetInput(context, node, 0);
const TfLiteIntArray* input_dims = input0->dims;
const TfLiteIntArray* output_dims = output->dims;
@@ -59,7 +59,7 @@
T* output_data = GetTensorData<T>(output);
for (int i = 0; i < values_count; ++i) {
- TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
+ const TfLiteTensor* t = GetInput(context, node, i);
const T* input_data = GetTensorData<T>(t);
for (int k = 0; k < outer_size; ++k) {
const T* input_ptr = input_data + copy_size * k;
diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc
index bd35393..8ad69ce 100644
--- a/tensorflow/lite/micro/kernels/quantize.cc
+++ b/tensorflow/lite/micro/kernels/quantize.cc
@@ -34,8 +34,8 @@
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
// TODO(b/128934713): Add support for fixed-point per-channel quantization.
// Currently this only support affine per-layer quantization.
@@ -56,8 +56,8 @@
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
tflite::QuantizationParams op_params;
op_params.zero_point = output->params.zero_point;
diff --git a/tensorflow/lite/micro/kernels/reduce.cc b/tensorflow/lite/micro/kernels/reduce.cc
index ec4491b..0705292 100644
--- a/tensorflow/lite/micro/kernels/reduce.cc
+++ b/tensorflow/lite/micro/kernels/reduce.cc
@@ -43,7 +43,7 @@
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Validate axis type
- const TfLiteTensor* axis = &context->tensors[node->inputs->data[1]];
+ const TfLiteTensor* axis = GetInput(context, node, 1);
TF_LITE_ENSURE_TYPES_EQ(context, axis->type, kTfLiteInt32);
return kTfLiteOk;
}
@@ -67,9 +67,9 @@
}
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
- const TfLiteTensor* axis = &context->tensors[node->inputs->data[1]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* axis = GetInput(context, node, 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteReducerParams* params =
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
diff --git a/tensorflow/lite/micro/kernels/split.cc b/tensorflow/lite/micro/kernels/split.cc
index d32a88e..6aaa37f 100644
--- a/tensorflow/lite/micro/kernels/split.cc
+++ b/tensorflow/lite/micro/kernels/split.cc
@@ -32,7 +32,7 @@
const TfLiteTensor* input, int axis_value) {
const int output_count = NumOutputs(node);
const TfLiteIntArray* input_dims = input->dims;
- const TfLiteTensor* output0 = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* output0 = GetOutput(context, node, 0);
const TfLiteIntArray* output_dims = output0->dims;
const int split_dimensions = input_dims->size;
@@ -57,7 +57,7 @@
const T* input_ptr = GetTensorData<T>(input);
for (int k = 0; k < outer_size; ++k) {
for (int i = 0; i < output_count; ++i) {
- TfLiteTensor* t = &context->tensors[node->outputs->data[i]];
+ TfLiteTensor* t = GetOutput(context, node, i);
T* output_data = GetTensorData<T>(t);
const int copy_size = output_dims->data[axis] * base_inner_size;
T* output_ptr = output_data + k * copy_size;
diff --git a/tensorflow/lite/micro/kernels/unpack.cc b/tensorflow/lite/micro/kernels/unpack.cc
index a189b92..9ca69f7 100644
--- a/tensorflow/lite/micro/kernels/unpack.cc
+++ b/tensorflow/lite/micro/kernels/unpack.cc
@@ -33,7 +33,7 @@
template <typename T>
TfLiteStatus UnpackImpl(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, int output_count, int axis) {
- const TfLiteTensor* output0 = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* output0 = GetOutput(context, node, 0);
const TfLiteIntArray* input_dims = input->dims;
const TfLiteIntArray* output_dims = output0->dims;
const int dimensions = input_dims->size;
@@ -61,7 +61,7 @@
const T* input_data = GetTensorData<T>(input);
for (int i = 0; i < output_count; ++i) {
- TfLiteTensor* t = &context->tensors[node->outputs->data[i]];
+ TfLiteTensor* t = GetOutput(context, node, i);
T* output_data = GetTensorData<T>(t);
for (int k = 0; k < outer_size; ++k) {
T* output_ptr = output_data + copy_size * k;
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc
index caa5677..7846898 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc
+++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc
@@ -125,7 +125,7 @@
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output = GetOutput(context, node, 0);
// TODO(b/132070898): Use statically slotted OpData structures until a
// scratch memory API is ready.
@@ -141,8 +141,8 @@
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* op_data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
tflite::QuantizationParams op_params;
op_params.zero_point = output->params.zero_point;
diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc
index 0846c38..9517a80 100644
--- a/tensorflow/lite/micro/micro_interpreter_test.cc
+++ b/tensorflow/lite/micro/micro_interpreter_test.cc
@@ -134,11 +134,11 @@
}
static TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ const TfLiteTensor* input = GetInput(context, node, 0);
const int32_t* input_data = input->data.i32;
- const TfLiteTensor* weight = &context->tensors[node->inputs->data[1]];
+ const TfLiteTensor* weight = GetInput(context, node, 1);
const uint8_t* weight_data = weight->data.uint8;
- TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* output = GetOutput(context, node, 0);
int32_t* output_data = output->data.i32;
output_data[0] =
0; // Catch output tensor sharing memory with an input tensor