/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h"
#include "flatbuffers/base.h" // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace micro {
namespace pooling {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

struct OpData {
  TfLitePaddingValues padding;
  // Index to buffer for optimizations if applicable.
  int buffer_idx;

  int32_t activation_min;
  int32_t activation_max;
};
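
// Computes the padding and, for quantized types, the activation range for a
// pooling node, and resets the scratch-buffer index to "none requested".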
TfLiteStatus CalculateOpData(TfLiteContext* context,
                             const TfLitePoolParams* params,
                             const TfLiteTensor* input, TfLiteTensor* output,
                             OpData* data) {
  // input: batch, height, width, channel
  int height = SizeOfDimension(input, 1);
  int width = SizeOfDimension(input, 2);

  int out_height, out_width;

  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      /*dilation_rate_height=*/1,
      /*dilation_rate_width=*/1, height, width, params->filter_height,
      params->filter_width, params->padding, &out_height, &out_width);

  if (input->type != kTfLiteFloat32) {
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->activation_min,
        &data->activation_max));
    TFLITE_DCHECK_LE(data->activation_min, data->activation_max);
  }

  // Set buffer index to a reset value.
  data->buffer_idx = -1;
  return kTfLiteOk;
}
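
// Float average pooling, handled by the portable reference kernel.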
void AverageEvalFloat(const TfLiteContext* context, const TfLiteNode* node,
                      const TfLitePoolParams* params, const OpData& data,
                      const TfLiteTensor* input, TfLiteTensor* output) {
  float activation_min, activation_max;
  CalculateActivationRange(params->activation, &activation_min,
                           &activation_max);

  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.float_activation_min = activation_min;
  op_params.float_activation_max = activation_max;
  reference_ops::AveragePool(
      op_params, GetTensorShape(input), GetTensorData<float>(input),
      GetTensorShape(output), GetTensorData<float>(output));
}
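
// Quantized average pooling: uint8 falls back to the reference kernel, while
// int8 is dispatched to the optimized CMSIS-NN kernel.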
void AverageEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
                          const TfLitePoolParams* params, const OpData& data,
                          const TfLiteTensor* input, TfLiteTensor* output) {
  TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);

  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.quantized_activation_min = data.activation_min;
  op_params.quantized_activation_max = data.activation_max;

  if (input->type == kTfLiteUInt8) {
    reference_ops::AveragePool(
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(output), GetTensorData<uint8_t>(output));
  } else {
    RuntimeShape input_shape = GetTensorShape(input);
    TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);

    RuntimeShape output_shape = GetTensorShape(output);
    TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

    const int depth = MatchingDim(input_shape, 3, output_shape, 3);

    // CMSIS-NN expects NHWC dims; the batch size is fixed to 1 here.
    cmsis_nn_dims input_dims;
    input_dims.n = 1;
    input_dims.h = input_shape.Dims(1);
    input_dims.w = input_shape.Dims(2);
    input_dims.c = depth;

    cmsis_nn_dims output_dims;
    output_dims.n = 1;
    output_dims.h = output_shape.Dims(1);
    output_dims.w = output_shape.Dims(2);
    output_dims.c = depth;

    cmsis_nn_pool_params pool_params;
    pool_params.stride.h = params->stride_height;
    pool_params.stride.w = params->stride_width;
    pool_params.padding.h = data.padding.height;
    pool_params.padding.w = data.padding.width;
    pool_params.activation.min = data.activation_min;
    pool_params.activation.max = data.activation_max;

    cmsis_nn_dims filter_dims;
    filter_dims.n = 1;
    filter_dims.h = params->filter_height;
    filter_dims.w = params->filter_width;
    filter_dims.c = 1;

    // Use the scratch buffer requested in AveragePrepare, if any.
    cmsis_nn_context ctx;
    ctx.buf = nullptr;
    ctx.size = 0;
    if (data.buffer_idx > -1) {
      ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
    }

    TFLITE_DCHECK_EQ(
        arm_avgpool_s8(&ctx, &pool_params, &input_dims,
                       GetTensorData<int8_t>(input), &filter_dims,
                       &output_dims, GetTensorData<int8_t>(output)),
        ARM_MATH_SUCCESS);
  }
}
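
// Float max pooling, handled by the portable reference kernel.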
void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
                  TfLitePoolParams* params, const OpData& data,
                  TfLiteTensor* input, TfLiteTensor* output) {
  float activation_min, activation_max;
  CalculateActivationRange(params->activation, &activation_min,
                           &activation_max);

  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.float_activation_min = activation_min;
  op_params.float_activation_max = activation_max;
  reference_ops::MaxPool(op_params, GetTensorShape(input),
                         GetTensorData<float>(input), GetTensorShape(output),
                         GetTensorData<float>(output));
}
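
// Uint8 max pooling, handled by the portable reference kernel.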
void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node,
                           TfLitePoolParams* params, const OpData& data,
                           TfLiteTensor* input, TfLiteTensor* output) {
  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data.padding.height;
  op_params.padding_values.width = data.padding.width;
  op_params.quantized_activation_min = data.activation_min;
  op_params.quantized_activation_max = data.activation_max;
  reference_ops::MaxPool(op_params, GetTensorShape(input),
                         GetTensorData<uint8_t>(input), GetTensorShape(output),
                         GetTensorData<uint8_t>(output));
}
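
// Int8 max pooling, dispatched to the optimized CMSIS-NN kernel.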
TfLiteStatus MaxEvalInt8(TfLiteContext* context, const TfLiteNode* node,
                         const TfLitePoolParams* params, const OpData& data,
                         TfLiteTensor* input, TfLiteTensor* output) {
  RuntimeShape input_shape = GetTensorShape(input);
  RuntimeShape output_shape = GetTensorShape(output);
  const int depth = MatchingDim(input_shape, 3, output_shape, 3);

  // CMSIS-NN expects NHWC dims; the batch size is fixed to 1 here.
  cmsis_nn_dims input_dims;
  input_dims.n = 1;
  input_dims.h = input_shape.Dims(1);
  input_dims.w = input_shape.Dims(2);
  input_dims.c = depth;

  cmsis_nn_dims output_dims;
  output_dims.n = 1;
  output_dims.h = output_shape.Dims(1);
  output_dims.w = output_shape.Dims(2);
  output_dims.c = depth;

  cmsis_nn_pool_params pool_params;
  pool_params.stride.h = params->stride_height;
  pool_params.stride.w = params->stride_width;
  pool_params.padding.h = data.padding.height;
  pool_params.padding.w = data.padding.width;
  pool_params.activation.min = data.activation_min;
  pool_params.activation.max = data.activation_max;

  cmsis_nn_dims filter_dims;
  filter_dims.n = 1;
  filter_dims.h = params->filter_height;
  filter_dims.w = params->filter_width;
  filter_dims.c = 1;

  cmsis_nn_context ctx;
  ctx.buf = nullptr;
  ctx.size = 0;
  if (data.buffer_idx > -1) {
    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
  }

  TFLITE_DCHECK_EQ(
      arm_max_pool_s8(&ctx, &pool_params, &input_dims,
                      GetTensorData<int8_t>(input), &filter_dims, &output_dims,
                      GetTensorData<int8_t>(output)),
      ARM_MATH_SUCCESS);
  return kTfLiteOk;
}

}  // namespace
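
// Allocates a persistent OpData instance for this node.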
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
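
// Prepare for MAX_POOL_2D: computes padding and the activation range.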
TfLiteStatus MaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));
  return kTfLiteOk;
}
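
// Prepare for AVERAGE_POOL_2D: in addition to the common setup, requests the
// scratch buffer that the CMSIS-NN int8 kernel may need.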
TfLiteStatus AveragePrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, data));

  if (input->type == kTfLiteInt8) {
    RuntimeShape input_shape = GetTensorShape(input);
    TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);

    RuntimeShape output_shape = GetTensorShape(output);
    TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

    const int depth = MatchingDim(input_shape, 3, output_shape, 3);
    const int output_width = output_shape.Dims(2);

    // Ask CMSIS-NN how much scratch memory arm_avgpool_s8 needs, and request
    // it from the arena if it is non-zero.
    const int32_t buffer_size =
        arm_avgpool_s8_get_buffer_size(output_width, depth);

    if (buffer_size > 0) {
      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
          context, buffer_size, &data->buffer_idx));
    } else {
      data->buffer_idx = -1;
    }
  }
  return kTfLiteOk;
}
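
// Invoke for AVERAGE_POOL_2D: dispatches on the input type.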
TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Inputs and outputs share the same type, guaranteed by the converter.
  switch (input->type) {
    case kTfLiteFloat32:
      AverageEvalFloat(context, node, params, data, input, output);
      break;
    case kTfLiteUInt8:
    case kTfLiteInt8:
      AverageEvalQuantized(context, node, params, data, input, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
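
// Invoke for MAX_POOL_2D: dispatches on the input type.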
TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  // The input is fetched directly (rather than through GetInput) because the
  // Max eval helpers take a non-const TfLiteTensor*.
  TfLiteTensor* input = &context->tensors[flatbuffers::EndianScalar(
      node->inputs->data[kInputTensor])];
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  switch (input->type) {
    case kTfLiteFloat32:
      MaxEvalFloat(context, node, params, data, input, output);
      break;
    case kTfLiteUInt8:
      MaxEvalQuantizedUInt8(context, node, params, data, input, output);
      break;
    case kTfLiteInt8:
      return MaxEvalInt8(context, node, params, data, input, output);
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace pooling

TfLiteRegistration Register_AVERAGE_POOL_2D() {
  return {/*init=*/pooling::Init,
          /*free=*/nullptr,
          /*prepare=*/pooling::AveragePrepare,
          /*invoke=*/pooling::AverageEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_MAX_POOL_2D() {
  return {/*init=*/pooling::Init,
          /*free=*/nullptr,
          /*prepare=*/pooling::MaxPrepare,
          /*invoke=*/pooling::MaxEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite