blob: 591617fb52b584f405985dce246653c8a2f95761 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "CpuOperationUtils.h"
#include "Operations.h"
#include <cfloat>
#include <cmath>
#include "Tracing.h"
namespace android {
namespace nn {
template <typename T_Input>
bool roiPoolingImpl(const T_Input* inputData, const Shape& inputShape, const float* roiData,
const Shape& roiShape, float spatialScale, T_Input* outputData,
const Shape& outputShape) {
NNTRACE_TRANS("RoiPooling");
const uint32_t kRoiDim = 4;
uint32_t inHeight = getSizeOfDimension(inputShape, 1);
uint32_t inWidth = getSizeOfDimension(inputShape, 2);
uint32_t inDepth = getSizeOfDimension(inputShape, 3);
uint32_t outHeight = getSizeOfDimension(outputShape, 1);
uint32_t outWidth = getSizeOfDimension(outputShape, 2);
uint32_t numRois = getSizeOfDimension(roiShape, 0);
uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
T_Input* outPtr = outputData;
const float* roiDataEnd = roiData + numRois * roiInfoLength;
for (const float* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim) {
uint32_t batchId = 0;
// get optional batch id
if (roiInfoLength == kRoiDim + 1) {
batchId = std::round(roiInfo[0]);
roiInfo++;
}
const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
int32_t wRoiStart = std::round(roiInfo[0] * spatialScale);
int32_t hRoiStart = std::round(roiInfo[1] * spatialScale);
int32_t wRoiEnd = std::round(roiInfo[2] * spatialScale);
int32_t hRoiEnd = std::round(roiInfo[3] * spatialScale);
// Rois with width/height < 1 are considered malformed and are forced to be 1
float roiWidth = static_cast<float>(std::max(wRoiEnd - wRoiStart + 1, 1));
float roiHeight = static_cast<float>(std::max(hRoiEnd - hRoiStart + 1, 1));
float wStepSize = roiWidth / static_cast<float>(outWidth);
float hStepSize = roiHeight / static_cast<float>(outHeight);
for (uint32_t i = 0; i < outHeight; i++) {
for (uint32_t j = 0; j < outWidth; j++) {
// Take floor on start, ceil on end, start included, end excluded, i.e. [start, end)
// end is guaranteed to larger than start by at least 1
uint32_t wStart = std::floor(wStepSize * j + wRoiStart);
uint32_t wEnd = std::ceil(wStepSize * (j + 1) + wRoiStart);
uint32_t hStart = std::floor(hStepSize * i + hRoiStart);
uint32_t hEnd = std::ceil(hStepSize * (i + 1) + hRoiStart);
wStart = std::min(wStart, inWidth);
wEnd = std::min(wEnd, inWidth);
hStart = std::min(hStart, inHeight);
hEnd = std::min(hEnd, inHeight);
for (uint32_t k = 0; k < inDepth; k++) {
T_Input maxValue;
bool first = true;
for (uint32_t h = hStart; h < hEnd; h++) {
for (uint32_t w = wStart; w < wEnd; w++) {
T_Input inputValue = batchBase[h * inWidth * inDepth + w * inDepth + k];
if (first || inputValue > maxValue) {
maxValue = inputValue;
first = false;
}
}
}
outPtr[k] = maxValue;
}
outPtr += inDepth;
}
}
}
return true;
}
// Type-erased entry point for ROI pooling: reinterprets the raw byte buffers according to
// inputShape.type and forwards to the typed roiPoolingImpl. The ROI buffer is always read as
// float32, whichever input type is selected. Unsupported input types are logged and reported
// by returning false.
bool roiPoolingGeneric(const uint8_t* inputData, const Shape& inputShape, const uint8_t* roiData,
                       const Shape& roiShape, float spatialScale, uint8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_TRANS("roiPoolingGeneric");
    const float* roiCoords = reinterpret_cast<const float*>(roiData);
    switch (inputShape.type) {
        case OperandType::TENSOR_FLOAT32:
            return roiPoolingImpl<float>(reinterpret_cast<const float*>(inputData), inputShape,
                                         roiCoords, roiShape, spatialScale,
                                         reinterpret_cast<float*>(outputData), outputShape);
        case OperandType::TENSOR_QUANT8_ASYMM:
            // Input and output are already byte buffers; no reinterpretation is needed.
            return roiPoolingImpl<uint8_t>(inputData, inputShape, roiCoords, roiShape,
                                           spatialScale, outputData, outputShape);
        default:
            LOG(ERROR) << "Unsupported data type";
            return false;
    }
}
} // namespace nn
} // namespace android