// blob: 524e8da864729dc8805be3f146484a1ff126d4d1 [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
#define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
#include "NeuralNetworksExtensions.h"
#include "NeuralNetworksWrapper.h"
#include <variant>
namespace android {
namespace nn {
namespace extension_wrapper {
// Bring the base wrapper's SymmPerChannelQuantParams and Type into this namespace so the
// definitions below can refer to them unqualified.
using wrapper::SymmPerChannelQuantParams;
using wrapper::Type;
// Byte buffer describing extension-specific operand parameters. Model::addOperand forwards
// these bytes to ANeuralNetworksModel_setOperandExtensionData.
struct ExtensionOperandParams {
    // The raw parameter bytes.
    std::vector<uint8_t> data;

    // Adopts an already-built byte buffer.
    ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {}

    // Snapshots the sizeof(T) bytes that make up `data`. T must be trivially copyable,
    // which is enforced at compile time.
    template <typename T>
    ExtensionOperandParams(const T& data)
        : data(reinterpret_cast<const uint8_t*>(&data),
               reinterpret_cast<const uint8_t*>(&data) + sizeof(T)) {
        static_assert(std::is_trivially_copyable<T>::value, "data must be trivially copyable");
    }
};
// Operand description that bundles the C-level ANeuralNetworksOperandType with the storage it
// points into (the dimensions vector) and optional extra parameters.
struct OperandType {
    // Extra per-operand parameters: none (std::monostate), symmetric per-channel quantization
    // parameters, or extension-defined data.
    using ExtraParams =
            std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>;

    // The struct handed to the C API; its `dimensions` pointer refers to the vector below.
    ANeuralNetworksOperandType operandType;
    // Backing storage for operandType.dimensions.
    std::vector<uint32_t> dimensions;
    ExtraParams extraParams;

    // Copying re-points operandType.dimensions at this object's own vector; a plain memberwise
    // copy would leave it aimed at `other`'s storage. Because these copy operations are
    // user-declared, no implicit move operations exist and moving an OperandType copies it.
    OperandType(const OperandType& other)
        : operandType(other.operandType),
          dimensions(other.dimensions),
          extraParams(other.extraParams) {
        operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
    }

    OperandType& operator=(const OperandType& other) {
        if (this != &other) {
            operandType = other.operandType;
            dimensions = other.dimensions;
            extraParams = other.extraParams;
            // Same fix-up as the copy constructor: point at our own storage.
            operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
        }
        return *this;
    }

    // Primary constructor. `d` is taken by value and moved into `dimensions`; operandType is
    // then filled in from the stored vector (nullptr when there are no dimensions).
    OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0,
                ExtraParams&& extraParams = std::monostate())
        : dimensions(std::move(d)), extraParams(std::move(extraParams)) {
        operandType = {
                .type = static_cast<int32_t>(type),
                .dimensionCount = static_cast<uint32_t>(dimensions.size()),
                .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr,
                .scale = scale,
                .zeroPoint = zeroPoint,
        };
    }

    // Convenience overload taking per-channel quantization parameters directly.
    // The by-value `dimensions` parameter is moved (not copied) into the primary constructor.
    OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint,
                SymmPerChannelQuantParams&& channelQuant)
        : OperandType(type, std::move(dimensions), scale, zeroPoint,
                      ExtraParams(std::move(channelQuant))) {}

    // Convenience overload for operands that need extra parameters but use the default
    // scale (0.0f) and zero point (0). `dimensions` is moved into the primary constructor.
    OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams)
        : OperandType(type, std::move(dimensions), 0.0f, 0, std::move(extraParams)) {}
};
// wrapper::Model extended with NNAPI extension support: lookup of extension operand and
// operation types, and an addOperand overload that understands extension_wrapper::OperandType.
// mModel, mValid and mNextOperandId are not declared here; presumably they are inherited from
// wrapper::Model.
class Model : public wrapper::Model {
   public:
    using wrapper::Model::Model;  // Inherit constructors.

    // Looks up the operand type code for `typeWithinExtension` of extension `extensionName`.
    // On failure the model is marked invalid (mValid = false) and a zero-initialized value is
    // returned, so callers should consult the model's validity rather than the return value.
    int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) {
        // Value-initialized so the error path never returns an indeterminate value.
        int32_t result{};
        if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension,
                                                         &result) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
        return result;
    }

    // Looks up the operation type code for `typeWithinExtension` of extension `extensionName`.
    // On failure the model is marked invalid (mValid = false) and a zero-initialized value is
    // returned.
    ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName,
                                                           uint16_t typeWithinExtension) {
        // Value-initialized so the error path never returns an indeterminate value.
        ANeuralNetworksOperationType result{};
        if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName,
                                                           typeWithinExtension,
                                                           &result) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
        return result;
    }

    // Adds the operand described by `type` and, if it carries extra parameters, attaches them
    // to the newly added operand. Any failing C call marks the model invalid. Returns the
    // index used for this operand; the internal counter advances regardless of failures.
    uint32_t addOperand(const OperandType* type) {
        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }

        if (const auto* channelQuant =
                    std::get_if<SymmPerChannelQuantParams>(&type->extraParams)) {
            if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
                        mModel, mNextOperandId, &channelQuant->params) !=
                ANEURALNETWORKS_NO_ERROR) {
                mValid = false;
            }
        } else if (const auto* extension =
                           std::get_if<ExtensionOperandParams>(&type->extraParams)) {
            if (ANeuralNetworksModel_setOperandExtensionData(
                        mModel, mNextOperandId, extension->data.data(),
                        extension->data.size()) != ANEURALNETWORKS_NO_ERROR) {
                mValid = false;
            }
        }
        // std::monostate: nothing extra to attach.

        return mNextOperandId++;
    }
};
} // namespace extension_wrapper
// Re-export the extension-aware types into android::nn::wrapper under Extension-prefixed
// names, alongside the base wrapper types.
namespace wrapper {
using ExtensionModel = extension_wrapper::Model;
using ExtensionOperandType = extension_wrapper::OperandType;
using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams;
} // namespace wrapper
} // namespace nn
} // namespace android
#endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H