/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#define LOG_TAG "FibonacciDriver"

#include "FibonacciDriver.h"

#include <vector>

#include "FibonacciExtension.h"
#include "HalInterfaces.h"
#include "NeuralNetworksExtensions.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Utils.h"
#include "ValidateHal.h"

namespace android {
namespace nn {
namespace sample_driver {
namespace {

using namespace hal;

const uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
const uint32_t kTypeWithinExtensionMask = (1 << kLowBitsType) - 1;
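// An extension operation or operand type is encoded as
// (prefix << kLowBitsType) | typeWithinExtension, so kTypeWithinExtensionMask
// extracts the type within the extension from the low bits.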

namespace fibonacci_op {

constexpr char kOperationName[] = "EXAMPLE_FIBONACCI";

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputN = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
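
// Looks up the numeric prefix that the runtime assigned to the Fibonacci
// extension in this model. The prefix forms the high bits of every extension
// operand and operation type value.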
bool getFibonacciExtensionPrefix(const Model& model, uint16_t* prefix) {
    NN_RET_CHECK_EQ(model.extensionNameToPrefix.size(), 1u);  // Assumes no other extensions in use.
    NN_RET_CHECK_EQ(model.extensionNameToPrefix[0].name, EXAMPLE_FIBONACCI_EXTENSION_NAME);
    *prefix = model.extensionNameToPrefix[0].prefix;
    return true;
}
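
// Returns true if the operation's type decodes to this extension's
// EXAMPLE_FIBONACCI operation.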
bool isFibonacciOperation(const Operation& operation, const Model& model) {
    int32_t operationType = static_cast<int32_t>(operation.type);
    uint16_t prefix;
    NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
    NN_RET_CHECK_EQ(operationType, (prefix << kLowBitsType) | EXAMPLE_FIBONACCI);
    return true;
}
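
// Validates the operation's signature: one input of type EXAMPLE_INT64 or
// TENSOR_FLOAT32 and one output of type EXAMPLE_TENSOR_QUANT64_ASYMM or
// TENSOR_FLOAT32.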
bool validate(const Operation& operation, const Model& model) {
    NN_RET_CHECK(isFibonacciOperation(operation, model));
    NN_RET_CHECK_EQ(operation.inputs.size(), kNumInputs);
    NN_RET_CHECK_EQ(operation.outputs.size(), kNumOutputs);
    int32_t inputType = static_cast<int32_t>(model.operands[operation.inputs[0]].type);
    int32_t outputType = static_cast<int32_t>(model.operands[operation.outputs[0]].type);
    uint16_t prefix;
    NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
    NN_RET_CHECK(inputType == ((prefix << kLowBitsType) | EXAMPLE_INT64) ||
                 inputType == ANEURALNETWORKS_TENSOR_FLOAT32);
    NN_RET_CHECK(outputType == ((prefix << kLowBitsType) | EXAMPLE_TENSOR_QUANT64_ASYMM) ||
                 outputType == ANEURALNETWORKS_TENSOR_FLOAT32);
    return true;
}
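
// Reads n from the input operand and resizes the output to a one-dimensional
// tensor of n elements.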
bool prepare(IOperationExecutionContext* context) {
    int64_t n;
    if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
        n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
    } else {
        n = context->getInputValue<int64_t>(kInputN);
    }
    NN_RET_CHECK_GE(n, 1);
    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions = {static_cast<uint32_t>(n)};
    return context->setOutputShape(kOutputTensor, output);
}
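
// Writes the first n Fibonacci numbers to output, then quantizes each element
// in place as output[i] / outputScale + outputZeroPoint. The float path passes
// scale 1 and zero point 0, which makes the quantization step a no-op.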
template <typename ScaleT, typename ZeroPointT, typename OutputT>
bool compute(int32_t n, ScaleT outputScale, ZeroPointT outputZeroPoint, OutputT* output) {
    // Compute the Fibonacci numbers.
    if (n >= 1) {
        output[0] = 1;
    }
    if (n >= 2) {
        output[1] = 1;
    }
    if (n >= 3) {
        for (int32_t i = 2; i < n; ++i) {
            output[i] = output[i - 1] + output[i - 2];
        }
    }
    // Quantize output.
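    // For example, with outputScale = 0.5 and outputZeroPoint = 10,
    // the raw value fib(5) = 5 is stored as 5 / 0.5 + 10 = 20.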
    for (int32_t i = 0; i < n; ++i) {
        output[i] = output[i] / outputScale + outputZeroPoint;
    }
    return true;
}
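
// Reads n (as in prepare()) and fills the output buffer. For a quantized
// output, the scale and zero point come from the operand's extension
// extraParams; for a float output, identity quantization parameters are used.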
bool execute(IOperationExecutionContext* context) {
    int64_t n;
    if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
        n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
    } else {
        n = context->getInputValue<int64_t>(kInputN);
    }
    if (context->getOutputType(kOutputTensor) == OperandType::TENSOR_FLOAT32) {
        float* output = context->getOutputBuffer<float>(kOutputTensor);
        return compute(n, /*scale=*/1.0, /*zeroPoint=*/0, output);
    } else {
        uint64_t* output = context->getOutputBuffer<uint64_t>(kOutputTensor);
        Shape outputShape = context->getOutputShape(kOutputTensor);
        auto outputQuant = reinterpret_cast<const ExampleQuant64AsymmParams*>(
                outputShape.extraParams.extension().data());
        return compute(n, outputQuant->scale, outputQuant->zeroPoint, output);
    }
}

}  // namespace fibonacci_op
}  // namespace
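
// Decodes operationType into its extension prefix and type within the
// extension, and returns the Fibonacci registration when the low bits match
// EXAMPLE_FIBONACCI.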
const OperationRegistration* FibonacciOperationResolver::findOperation(
        OperationType operationType) const {
    // .validate is omitted because it's not used by the extension driver.
    static OperationRegistration operationRegistration(operationType, fibonacci_op::kOperationName,
                                                       nullptr, fibonacci_op::prepare,
                                                       fibonacci_op::execute, {});
    uint16_t prefix = static_cast<int32_t>(operationType) >> kLowBitsType;
    uint16_t typeWithinExtension = static_cast<int32_t>(operationType) & kTypeWithinExtensionMask;
    // Assumes no other extensions in use.
    return prefix != 0 && typeWithinExtension == EXAMPLE_FIBONACCI ? &operationRegistration
                                                                   : nullptr;
}

Return<void> FibonacciDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
    cb(ErrorStatus::NONE,
       {
               {
                       .name = EXAMPLE_FIBONACCI_EXTENSION_NAME,
                       .operandTypes =
                               {
                                       {
                                               .type = EXAMPLE_INT64,
                                               .isTensor = false,
                                               .byteSize = 8,
                                       },
                                       {
                                               .type = EXAMPLE_TENSOR_QUANT64_ASYMM,
                                               .isTensor = true,
                                               .byteSize = 8,
                                       },
                               },
               },
       });
    return Void();
}
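
// A minimal client-side sketch (not part of this driver) of how an application
// could target the extension advertised above through NeuralNetworksExtensions.h.
// Error handling is omitted, and `device` and `model` are assumed to exist:
//
//     bool supported = false;
//     ANeuralNetworksDevice_getExtensionSupport(device, EXAMPLE_FIBONACCI_EXTENSION_NAME,
//                                               &supported);
//     int32_t nType;
//     ANeuralNetworksModel_getExtensionOperandType(model, EXAMPLE_FIBONACCI_EXTENSION_NAME,
//                                                  EXAMPLE_INT64, &nType);
//     ANeuralNetworksOperationType fibonacciType;
//     ANeuralNetworksModel_getExtensionOperationType(model, EXAMPLE_FIBONACCI_EXTENSION_NAME,
//                                                    EXAMPLE_FIBONACCI, &fibonacciType);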

Return<void> FibonacciDriver::getCapabilities_1_3(getCapabilities_1_3_cb cb) {
    android::nn::initVLogMask();
    VLOG(DRIVER) << "getCapabilities()";
    static const PerformanceInfo kPerf = {.execTime = 1.0f, .powerUsage = 1.0f};
    Capabilities capabilities = {
            .relaxedFloat32toFloat16PerformanceScalar = kPerf,
            .relaxedFloat32toFloat16PerformanceTensor = kPerf,
            .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf)};
    cb(ErrorStatus::NONE, capabilities);
    return Void();
}

Return<void> FibonacciDriver::getSupportedOperations_1_3(const V1_3::Model& model,
                                                         getSupportedOperations_1_3_cb cb) {
    VLOG(DRIVER) << "getSupportedOperations()";
    if (!validateModel(model)) {
        cb(ErrorStatus::INVALID_ARGUMENT, {});
        return Void();
    }
    const size_t count = model.operations.size();
    std::vector<bool> supported(count);
    for (size_t i = 0; i < count; ++i) {
        const Operation& operation = model.operations[i];
        if (fibonacci_op::isFibonacciOperation(operation, model)) {
            if (!fibonacci_op::validate(operation, model)) {
                cb(ErrorStatus::INVALID_ARGUMENT, {});
                return Void();
            }
            supported[i] = true;
        }
    }
    cb(ErrorStatus::NONE, supported);
    return Void();
}

}  // namespace sample_driver
}  // namespace nn
}  // namespace android