| /* |
| * Copyright (C) 2020 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "Validation.h" |
| |
| #include <android-base/logging.h> |
| #include <android-base/mapped_file.h> |
| |
| #include <algorithm> |
| #include <cctype> |
| #include <functional> |
| #include <limits> |
| #include <memory> |
| #include <numeric> |
| #include <set> |
| #include <sstream> |
| #include <string> |
| #include <string_view> |
| #include <tuple> |
| #include <utility> |
| #include <variant> |
| #include <vector> |
| |
| #include "ControlFlow.h" |
| #include "OperandTypes.h" |
| #include "OperationResolver.h" |
| #include "OperationTypes.h" |
| #include "Result.h" |
| #include "SharedMemory.h" |
| #include "TypeUtils.h" |
| #include "Types.h" |
| |
| // The NN_VALIDATE family of macros defined below is similar to the CHECK family defined in |
| // system/libbase/include/android-base/logging.h |
| // |
| // The difference is that NN_VALIDATE macros use LOG(ERROR) instead of LOG(FATAL) |
| // and return false instead of aborting. |
| |
| // Logs an error and returns false or INVALID. Append context using << after. For example: |
| // |
| // NN_VALIDATE_FAIL() << "Something went wrong"; |
| // |
| // The containing function must return a bool or Version. |
#define NN_VALIDATE_FAIL() \
    return NN_ERROR() << "NN_VALIDATE failed (" << __FILE__ << ":" << __LINE__ << "): "

// Logs an error and returns false or Version::INVALID if condition is false. Extra logging can be
// appended using << after. For example:
//
//   NN_VALIDATE(false) << "Something went wrong";
//
// The containing function must return a bool.
#define NN_VALIDATE(condition) \
    while (UNLIKELY(!(condition))) NN_VALIDATE_FAIL() << #condition << " "

// Helper for NN_VALIDATE_xx(x, y) macros.
// Reuses android-base's MakeEagerEvaluator/LogNullGuard (the same machinery the CHECK_xx macros
// use) so each argument is evaluated exactly once and both values are printed on failure.
#define NN_VALIDATE_OP(LHS, RHS, OP)                                                        \
    for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS);                      \
         UNLIKELY(!(_values.lhs.v OP _values.rhs.v));                                       \
         /* empty */)                                                                       \
    NN_VALIDATE_FAIL()                                                                      \
            << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = "                   \
            << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
            << ", " << #RHS << " = "                                                        \
            << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
            << ") "

// Logs an error and returns false or Version::INVALID if a condition between x and y does not hold.
// Extra logging can be appended using << after. For example:
//
//   NN_VALIDATE_EQ(a, b) << "Something went wrong";
//
// The values must implement the appropriate comparison operator as well as
// `operator<<(std::ostream&, ...)`.
// The containing function must return a bool or Version.
#define NN_VALIDATE_EQ(x, y) NN_VALIDATE_OP(x, y, ==)
#define NN_VALIDATE_NE(x, y) NN_VALIDATE_OP(x, y, !=)
#define NN_VALIDATE_LE(x, y) NN_VALIDATE_OP(x, y, <=)
#define NN_VALIDATE_LT(x, y) NN_VALIDATE_OP(x, y, <)
#define NN_VALIDATE_GE(x, y) NN_VALIDATE_OP(x, y, >=)
#define NN_VALIDATE_GT(x, y) NN_VALIDATE_OP(x, y, >)
| |
| namespace android::nn { |
| namespace { |
| |
// A DataLocation pointer variant holding no pointer; operands whose value does not live behind a
// raw pointer must compare equal to this.
constexpr auto kNullptrVariant = std::variant<const void*, void*>{};
// Default-constructed MemoryDomainToken used as the "invalid token" sentinel.
constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{};
| |
| template <typename Type, typename ValidationFunction> |
| Result<Version> validateVector(const std::vector<Type>& objects, |
| const ValidationFunction& validationFunction) { |
| auto version = Version::ANDROID_OC_MR1; |
| for (const auto& object : objects) { |
| version = combineVersions(version, NN_TRY(validationFunction(object))); |
| } |
| return version; |
| } |
| |
// Returns true if `name` is a valid extension name: a string made only of lowercase letters,
// digits, periods, and underscores, containing at least one period (reverse-domain style, e.g.
// "com.example.extension"). The empty string is rejected because it contains no period.
bool isValidExtensionName(const std::string& name) {
    constexpr auto validSymbol = [](char symbol) {
        // Cast to unsigned char before calling the <cctype> classifiers: passing a negative
        // char value (e.g. a non-ASCII byte on platforms where char is signed) is undefined
        // behavior for std::islower/std::isdigit.
        const auto uc = static_cast<unsigned char>(symbol);
        return std::islower(uc) || std::isdigit(uc) || symbol == '.' || symbol == '_';
    };
    const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol);
    const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end();
    return hasOnlyValidSymbols && hasAtLeastOnePeriod;
}
| |
| Result<Version> validateDeviceStatus(const DeviceStatus& deviceStatus) { |
| switch (deviceStatus) { |
| case DeviceStatus::AVAILABLE: |
| case DeviceStatus::BUSY: |
| case DeviceStatus::OFFLINE: |
| case DeviceStatus::UNKNOWN: |
| return Version::ANDROID_OC_MR1; |
| } |
| NN_VALIDATE_FAIL() << "Invalid DeviceStatus " << deviceStatus; |
| } |
| |
| Result<Version> validateExecutionPreference(const ExecutionPreference& executionPreference) { |
| switch (executionPreference) { |
| case ExecutionPreference::FAST_SINGLE_ANSWER: |
| // ExecutionPreference::FAST_SINGLE_ANSWER is the default value, so it is implicitly |
| // valid for all versions. |
| return Version::ANDROID_OC_MR1; |
| case ExecutionPreference::LOW_POWER: |
| case ExecutionPreference::SUSTAINED_SPEED: |
| return Version::ANDROID_P; |
| } |
| NN_VALIDATE_FAIL() << "Invalid ExecutionPreference " << executionPreference; |
| } |
| |
| Result<Version> validateDeviceType(const DeviceType& deviceType) { |
| switch (deviceType) { |
| case DeviceType::UNKNOWN: |
| // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when |
| // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a |
| // valid code to return for a driver that implement at least a 1.2 NN HAL. If we need a |
| // range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for |
| // DeviceType::UNKNOWN. |
| return Version::ANDROID_OC_MR1; |
| case DeviceType::OTHER: |
| case DeviceType::CPU: |
| case DeviceType::GPU: |
| case DeviceType::ACCELERATOR: |
| return Version::ANDROID_Q; |
| } |
| NN_VALIDATE_FAIL() << "Invalid DeviceType " << deviceType; |
| } |
| |
| Result<Version> validateMeasureTiming(const MeasureTiming& measureTiming) { |
| switch (measureTiming) { |
| case MeasureTiming::NO: |
| // MeasureTiming::NO is the default value, so it is implicitly valid for all versions. |
| return Version::ANDROID_OC_MR1; |
| case MeasureTiming::YES: |
| return Version::ANDROID_Q; |
| } |
| NN_VALIDATE_FAIL() << "Invalid MeasureTiming " << measureTiming; |
| } |
| |
// Returns the earliest version that supports the given operand type, or an error for an
// unrecognized value. Values that are not a known enumerator but satisfy isExtension() are
// extension operand types, supported from Android Q (NN HAL 1.2) onward.
Result<Version> validateOperandType(const OperandType& operandType) {
    switch (operandType) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            return Version::ANDROID_OC_MR1;
        case OperandType::BOOL:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
            return Version::ANDROID_Q;
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case OperandType::SUBGRAPH:
            return Version::ANDROID_R;
    }
    if (isExtension(operandType)) {
        return Version::ANDROID_Q;
    }
    NN_VALIDATE_FAIL() << "Invalid OperandType " << operandType;
}
| |
// Validates an operand's lifetime and returns the earliest version supporting it.
// SUBGRAPH operands are special: the SUBGRAPH type and the SUBGRAPH lifetime must always appear
// together, and both require Android R.
Result<Version> validateOperandLifeTime(const Operand& operand) {
    // Make sure SUBGRAPH operand type and lifetime always go together.
    NN_VALIDATE_EQ((operand.type == OperandType::SUBGRAPH),
                   (operand.lifetime == Operand::LifeTime::SUBGRAPH))
            << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime;

    switch (operand.lifetime) {
        case Operand::LifeTime::TEMPORARY_VARIABLE:
        case Operand::LifeTime::SUBGRAPH_INPUT:
        case Operand::LifeTime::SUBGRAPH_OUTPUT:
        case Operand::LifeTime::CONSTANT_COPY:
        case Operand::LifeTime::CONSTANT_REFERENCE:
        case Operand::LifeTime::NO_VALUE:
        case Operand::LifeTime::POINTER:
            return Version::ANDROID_OC_MR1;
        case Operand::LifeTime::SUBGRAPH:
            return Version::ANDROID_R;
    }
    NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
}
| |
| Result<Version> validatePriority(const Priority& priority) { |
| switch (priority) { |
| case Priority::MEDIUM: |
| // Priority::MEDIUM is the default value, so it is implicitly valid for all versions. |
| return Version::ANDROID_OC_MR1; |
| case Priority::LOW: |
| case Priority::HIGH: |
| return Version::ANDROID_R; |
| } |
| NN_VALIDATE_FAIL() << "Invalid Priority " << priority; |
| } |
| |
// Validates an ErrorStatus code. Every recognized code validates as the earliest version.
Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) {
    // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced in
    // ANDROID_R, but these can be cast to ANDROID_OC_MR1 as GENERAL_FAILURE.
    switch (errorStatus) {
        case ErrorStatus::NONE:
        case ErrorStatus::DEVICE_UNAVAILABLE:
        case ErrorStatus::GENERAL_FAILURE:
        case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case ErrorStatus::INVALID_ARGUMENT:
        case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
        case ErrorStatus::DEAD_OBJECT:
            return Version::ANDROID_OC_MR1;
    }
    NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus;
}
| |
| Result<Version> validateFusedActivationFunc(const FusedActivationFunc& activation) { |
| switch (activation) { |
| case FusedActivationFunc::NONE: |
| case FusedActivationFunc::RELU: |
| case FusedActivationFunc::RELU1: |
| case FusedActivationFunc::RELU6: |
| return Version::ANDROID_OC_MR1; |
| } |
| NN_VALIDATE_FAIL() << "Invalid FusedActivationFunc " << activation; |
| } |
| |
// OutputShape has no fields requiring validation; the type itself was introduced in Android Q
// (NN HAL 1.2), so its mere presence requires at least that version.
Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) {
    return Version::ANDROID_Q;
}
| |
// Validates a Timing result. A default-constructed Timing is implicitly valid for all versions;
// any reported measurement requires Android Q, and when both durations are present, time spent on
// the device must not exceed time spent in the driver.
Result<Version> validateTiming(const Timing& timing) {
    constexpr auto kNoTiming = Timing{};
    if (timing == kNoTiming) {
        // kNoTiming is the default value, so it is implicitly valid for all versions.
        return Version::ANDROID_OC_MR1;
    }
    if (timing.timeInDriver.has_value() && timing.timeOnDevice.has_value()) {
        // `lazyMessage` is a lazy function to produce the timing validation error message.
        // Currently, the code is not able to inline the message in NN_VALIDATE due to a
        // argument-dependent lookup issue with nn::detail::ErrorBuilder interacting with std types
        // such as std::chrono::duration, so this function uses an indirection through
        // std::ostringstream.
        const auto lazyMessage = [&timing]() -> std::string {
            std::ostringstream oss;
            oss << "Timing::timeOnDevice (" << timing.timeOnDevice.value()
                << ") must not exceed Timing::timeInDriver (" << timing.timeInDriver.value() << ")";
            return oss.str();
        };
        NN_VALIDATE(timing.timeOnDevice.value() <= timing.timeInDriver.value()) << lazyMessage();
    }
    return Version::ANDROID_Q;
}
| |
// Both performance metrics must be strictly positive; zero or negative values are meaningless.
Result<Version> validateCapabilitiesPerformanceInfo(
        const Capabilities::PerformanceInfo& performanceInfo) {
    NN_VALIDATE_GT(performanceInfo.execTime, 0.0f);
    NN_VALIDATE_GT(performanceInfo.powerUsage, 0.0f);
    return Version::ANDROID_OC_MR1;
}
| |
| Result<Version> validateCapabilitiesOperandPerformance( |
| const Capabilities::OperandPerformance& operandPerformance) { |
| auto version = NN_TRY(validateOperandType(operandPerformance.type)); |
| return combineVersions(version, |
| NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info))); |
| } |
| |
// Validates each entry of the table; the table as a whole is castable to any version, so the
// entries' individual version requirements are checked but deliberately not propagated.
Result<Version> validateCapabilitiesOperandPerformanceTable(
        const Capabilities::OperandPerformanceTable& operandPerformances) {
    // OperandPerformanceTable's order was validated when it was created, and it is castable to any
    // version. If an OperandType does not exist in the lower version being converted to, that
    // OperandPerformance will be dropped.
    NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance));
    return Version::ANDROID_OC_MR1;
}
| |
| Result<Version> validateCapabilities(const Capabilities& capabilities) { |
| auto version = |
| NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance)); |
| |
| version = combineVersions(version, |
| NN_TRY(validateCapabilitiesPerformanceInfo( |
| capabilities.relaxedFloat32toFloat16PerformanceScalar))); |
| version = combineVersions(version, |
| NN_TRY(validateCapabilitiesPerformanceInfo( |
| capabilities.relaxedFloat32toFloat16PerformanceTensor))); |
| version = combineVersions( |
| version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance))); |
| version = combineVersions( |
| version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance))); |
| |
| return version; |
| } |
| |
// An extension operand type must declare a non-zero byte size. Extensions were introduced in
// Android Q (NN HAL 1.2), so the information itself requires at least that version.
Result<Version> validateExtensionOperandTypeInformation(
        const Extension::OperandTypeInformation& operandTypeInformation) {
    NN_VALIDATE_GT(operandTypeInformation.byteSize, 0u);
    return Version::ANDROID_Q;
}
| |
// Validates one Extension: the name must be well formed, the operand type codes must be unique
// within the extension, and each OperandTypeInformation must itself be valid. Extensions require
// at least Android Q.
Result<Version> validateExtension(const Extension& extension) {
    NN_VALIDATE(isValidExtensionName(extension.name));

    // Verify all OperandTypeInformations have unique types.
    std::vector<uint16_t> types;
    types.reserve(extension.operandTypes.size());
    std::transform(extension.operandTypes.begin(), extension.operandTypes.end(),
                   std::back_inserter(types),
                   [](const Extension::OperandTypeInformation& operandTypeInformation) {
                       return operandTypeInformation.type;
                   });
    // Sort then look for adjacent equal values -- any duplicate ends up adjacent.
    std::sort(types.begin(), types.end());
    const auto iter = std::adjacent_find(types.begin(), types.end());
    NN_VALIDATE(iter == types.end()) << "Extension has duplicate type " << *iter;

    return combineVersions(Version::ANDROID_Q,
                           NN_TRY(validateVector(extension.operandTypes,
                                                 validateExtensionOperandTypeInformation)));
}
| |
// Validates a list of extensions and additionally checks that no two extensions share a name.
Result<Version> validateExtensions(const std::vector<Extension>& extensions) {
    const auto version = NN_TRY(validateVector(extensions, validateExtension));

    // Verify all extensions have unique names. reference_wrapper avoids copying the strings;
    // the explicit std::less/std::equal_to<std::string> comparators compare the referenced
    // strings rather than the wrappers.
    std::vector<std::reference_wrapper<const std::string>> names;
    names.reserve(extensions.size());
    std::transform(extensions.begin(), extensions.end(), std::back_inserter(names),
                   [](const Extension& extension) { return std::cref(extension.name); });
    std::sort(names.begin(), names.end(), std::less<std::string>{});
    const auto nameIter =
            std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
    NN_VALIDATE(nameIter == names.end())
            << "Two or more extensions have the duplicate name " << nameIter->get();

    return version;
}
| |
// Forward declaration of subgraph validation function, needed because operand validation
// recurses into referenced subgraphs (SUBGRAPH operands index into `referenced`).
// NOTE(review): `subgraphVersionCache` appears to memoize per-subgraph validation results --
// confirm against the definition, which is outside this chunk.
Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
                                      std::optional<size_t> referencedIndex,
                                      size_t operandValuesSize,
                                      const std::vector<size_t>& poolSizes,
                                      const std::vector<Model::Subgraph>& referenced,
                                      std::vector<std::optional<Version>>* subgraphVersionCache);
| |
// Validates an operand's DataLocation against its lifetime: each lifetime permits only a specific
// combination of pointer/poolIndex/offset/length, and any referenced range (operand-values blob,
// memory pool, or subgraph index) must be in bounds. Recursively validates referenced subgraphs
// for SUBGRAPH operands.
Result<Version> validateOperandDataLocation(
        const Operand& operand, size_t operandValuesSize, const std::vector<size_t>& poolSizes,
        const std::vector<Model::Subgraph>& subgraphs,
        std::vector<std::optional<Version>>* subgraphVersionCache) {
    const DataLocation& location = operand.location;
    // Model operands never carry padding.
    NN_VALIDATE_EQ(location.padding, 0u)
            << "DataLocation with a non-zero padding used in Model: " << location.padding;
    switch (operand.lifetime) {
        case Operand::LifeTime::CONSTANT_COPY:
            // Value is stored inline in the model's operand-values blob at [offset, offset+length).
            NN_VALIDATE(location.pointer == kNullptrVariant)
                    << "CONSTANT_COPY with a non-null pointer";
            NN_VALIDATE_EQ(location.poolIndex, 0u)
                    << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex;
            // Do the addition using uint64_t to avoid potential wrap-around problems.
            NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
                           operandValuesSize)
                    << "OperandValue location out of range. Starts at " << location.offset
                    << ", length " << location.length << ", max " << operandValuesSize;
            return Version::ANDROID_OC_MR1;
        case Operand::LifeTime::CONSTANT_REFERENCE:
            // Value lives in an external memory pool at [offset, offset+length).
            NN_VALIDATE_LT(location.poolIndex, poolSizes.size());
            // Do the addition using uint64_t to avoid potential wrap-around problems.
            NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
                           poolSizes[location.poolIndex])
                    << "OperandValue location out of range. Starts at " << location.offset
                    << ", length " << location.length << ", max " << poolSizes[location.poolIndex];
            return Version::ANDROID_OC_MR1;
        case Operand::LifeTime::TEMPORARY_VARIABLE:
        case Operand::LifeTime::SUBGRAPH_INPUT:
        case Operand::LifeTime::SUBGRAPH_OUTPUT:
        case Operand::LifeTime::NO_VALUE:
            // These lifetimes carry no stored value: every location field must be zero/null.
            NN_VALIDATE(location.pointer == kNullptrVariant)
                    << "Unexpected pointer value for operand of lifetime " << operand.lifetime;
            NN_VALIDATE_EQ(location.poolIndex, 0u)
                    << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime "
                    << operand.lifetime;
            NN_VALIDATE_EQ(location.offset, 0u) << "Unexpected offset " << location.offset
                                                << " for operand of lifetime " << operand.lifetime;
            NN_VALIDATE_EQ(location.length, 0u) << "Unexpected length " << location.length
                                                << " for operand of lifetime " << operand.lifetime;
            return Version::ANDROID_OC_MR1;
        case Operand::LifeTime::SUBGRAPH: {
            // location.offset is reused as an index into the model's referenced subgraphs.
            NN_VALIDATE(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer";
            NN_VALIDATE_EQ(location.poolIndex, 0u)
                    << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex;
            NN_VALIDATE_LT(location.offset, subgraphs.size())
                    << "Subgraph index out of range: " << location.offset
                    << " >= " << subgraphs.size();
            NN_VALIDATE_EQ(location.length, 0u)
                    << "SUBGRAPH with a non-zero length " << location.length;
            // Recursively validate the referenced subgraph; SUBGRAPH operands require Android R.
            const auto version = NN_TRY(validateModelSubgraph(
                    subgraphs[location.offset], location.offset, operandValuesSize, poolSizes,
                    subgraphs, subgraphVersionCache));
            return combineVersions(version, Version::ANDROID_R);
        }
        case Operand::LifeTime::POINTER: {
            // Value is accessed through a raw pointer, which must be non-null; the other
            // location fields are unused and must be zero.
            const bool nonNull =
                    std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer);
            NN_VALIDATE(nonNull) << "POINTER with a null pointer";
            NN_VALIDATE_EQ(location.poolIndex, 0u)
                    << "POINTER with a non-zero poolIndex " << location.poolIndex;
            NN_VALIDATE_EQ(location.offset, 0u)
                    << "POINTER with a non-zero offset " << location.offset;
            return Version::ANDROID_OC_MR1;
        }
    }
    NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
}
| |
// Checks that operand.dimensions is consistent with operand.type: scalar types must have rank 0,
// and constant (or pointer-backed) tensors must be fully specified -- known rank and no zero
// dimensions. Tensors of unspecified rank require Android Q.
Result<Version> validateOperandDimensions(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::FLOAT16:
        case OperandType::SUBGRAPH:
        case OperandType::OEM:
            NN_VALIDATE(operand.dimensions.empty())
                    << "Scalar data has dimensions of rank " << operand.dimensions.size();
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case OperandType::TENSOR_OEM_BYTE: {
            // Constants must have a fully-specified shape so their byte size is computable.
            if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
                operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
                operand.lifetime == Operand::LifeTime::POINTER) {
                NN_VALIDATE(!operand.dimensions.empty())
                        << "Tensor has lifetime of " << operand.lifetime
                        << " but dimensions of rank 0";
                const auto size = getNonExtensionSize(operand);
                NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow";
                NN_VALIDATE_NE(size.value(), 0u) << "Tensor has at least one unknown dimension";
            }
            // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before
            // Android Q?
            if (operand.dimensions.empty()) {
                // Unspecified rank was added in Android Q.
                return Version::ANDROID_Q;
            }
            return Version::ANDROID_OC_MR1;
        }
    }
    // Extension types are not checked here; their dimensions semantics are extension-defined.
    if (isExtension(operand.type)) {
        // Extension types were added in Android Q.
        return Version::ANDROID_Q;
    }
    NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
}
| |
// Checks that operand.scale is consistent with operand.type: non-quantized types (and extension
// types) must have scale 0, per-tensor quantized types must have a strictly positive scale,
// TENSOR_INT32 may have a zero or positive scale, and OEM types are not checked.
Result<Version> validateOperandScale(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::SUBGRAPH:
            NN_VALIDATE_EQ(operand.scale, 0.0f)
                    << "Operand of type " << operand.type << " with a non-zero scale ("
                    << operand.scale << ")";
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_INT32:
            // TENSOR_INT32 may be used with or without scale, depending on the operation.
            // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
            NN_VALIDATE_GE(operand.scale, 0.0f)
                    << "Operand of type " << operand.type << " with a negative scale";
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            NN_VALIDATE_GT(operand.scale, 0.0f)
                    << "Operand of type " << operand.type << " with a non-positive scale";
            return Version::ANDROID_OC_MR1;
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            // No validation for OEM types.
            return Version::ANDROID_OC_MR1;
    }
    // Extension types store quantization parameters in extraParams, not in scale.
    if (isExtension(operand.type)) {
        NN_VALIDATE_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type
                                            << " with a non-zero scale (" << operand.scale << ")";
        return Version::ANDROID_Q;
    }
    NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
}
| |
// Checks that operand.zeroPoint is consistent with operand.type: symmetric/non-quantized types
// (and extension types) must have zeroPoint 0, asymmetric quantized types must have a zeroPoint
// within their storage type's representable range, and OEM types are not checked.
Result<Version> validateOperandZeroPoint(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::SUBGRAPH:
            NN_VALIDATE_EQ(operand.zeroPoint, 0)
                    << "Operand of type " << operand.type << " with a non-zero zeroPoint "
                    << operand.zeroPoint;
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_QUANT8_ASYMM:
            // Range of uint8_t.
            NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 255)
                    << "Operand of type " << operand.type << " with an invalid zeroPoint "
                    << operand.zeroPoint << ", must be in range [0, 255]";
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            // Range of int8_t.
            NN_VALIDATE(operand.zeroPoint >= -128 && operand.zeroPoint <= 127)
                    << "Operand of type " << operand.type << " with an invalid zeroPoint "
                    << operand.zeroPoint << ", must be in range [-128, 127]";
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_QUANT16_ASYMM:
            // Range of uint16_t.
            NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535)
                    << "Operand of type " << operand.type << " with an invalid zeroPoint "
                    << operand.zeroPoint << ", must be in range [0, 65535]";
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_QUANT16_SYMM:
            NN_VALIDATE_EQ(operand.zeroPoint, 0)
                    << "Operand of type " << operand.type << " with a non-zero zeroPoint "
                    << operand.zeroPoint;
            return Version::ANDROID_OC_MR1;
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            // No validation for OEM types.
            return Version::ANDROID_OC_MR1;
    }
    if (isExtension(operand.type)) {
        NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type
                                             << " with a non-zero zeroPoint " << operand.zeroPoint;
        return Version::ANDROID_Q;
    }
    NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
}
| |
// Checks that operand.extraParams matches operand.type. Only TENSOR_QUANT8_SYMM_PER_CHANNEL may
// (and must) carry SymmPerChannelQuantParams -- whose channelDim must index a non-zero dimension
// and whose scales must be positive, one per channel. Extension types may carry ExtensionParams
// or nothing; all other types must carry NoParams.
Result<Version> validateOperandExtraParams(const Operand& operand) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::BOOL:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_BOOL8:
        case OperandType::FLOAT16:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case OperandType::SUBGRAPH:
            NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams))
                    << "Operand of type " << operand.type
                    << " has extraParams when there must be none";
            return Version::ANDROID_OC_MR1;
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
            NN_VALIDATE(
                    std::holds_alternative<Operand::SymmPerChannelQuantParams>(operand.extraParams))
                    << "Operand of type " << operand.type
                    << " without a Channel Quantization params";
            const auto& channelQuant =
                    std::get<Operand::SymmPerChannelQuantParams>(operand.extraParams);

            // channelDim must be a valid axis of the operand's shape.
            const size_t count = operand.dimensions.size();
            NN_VALIDATE_LT(channelQuant.channelDim, count)
                    << "Operand of type " << operand.type
                    << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
                    << ", must be valid dimension index in range [0, " << count << ")";
            // One scale per channel along channelDim, which must itself be specified (non-zero).
            const uint32_t expected = operand.dimensions[channelQuant.channelDim];
            NN_VALIDATE_EQ(channelQuant.scales.size(), expected)
                    << "Operand of type " << operand.type << " with a wrong-sized scales, expected "
                    << expected << " was " << channelQuant.scales.size();
            NN_VALIDATE_NE(expected, 0u)
                    << "Operand of type " << operand.type << " channel dimension "
                    << channelQuant.channelDim << " is underspecified (can't be 0)";
            for (uint32_t i = 0; i < expected; ++i) {
                NN_VALIDATE_GT(channelQuant.scales[i], 0.0f)
                        << "Operand of type " << operand.type
                        << " with a non-positive value in scales[" << i
                        << "]=" << channelQuant.scales[i];
            }
            return Version::ANDROID_Q;
        }
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
            // No validation for OEM types.
            return Version::ANDROID_OC_MR1;
    }
    if (isExtension(operand.type)) {
        NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams) ||
                    std::holds_alternative<Operand::ExtensionParams>(operand.extraParams))
                << "Extension operand of type " << operand.type
                << " must not have SymmPerChannelQuant extraParams";
        return Version::ANDROID_OC_MR1;
    }
    NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
}
| |
// Validates a single operand by composing every per-field validator (type, lifetime, dimensions,
// scale, zeroPoint, extraParams, data location) and returns the combined version requirement.
// Additionally checks that a constant operand's declared byte length matches the size computed
// from its type and dimensions.
Result<Version> validateOperand(const Operand& operand, size_t operandValuesSize,
                                const std::vector<size_t>& poolSizes,
                                const std::vector<Model::Subgraph>& subgraphs,
                                std::vector<std::optional<Version>>* subgraphVersionCache) {
    auto version = NN_TRY(validateOperandType(operand.type));
    version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand)));
    version = combineVersions(version, NN_TRY(validateOperandDimensions(operand)));
    version = combineVersions(version, NN_TRY(validateOperandScale(operand)));
    version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand)));
    version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand)));
    version = combineVersions(
            version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize, poolSizes,
                                                        subgraphs, subgraphVersionCache)));

    // For constants, validate that the length is as expected. The other lifetimes
    // expect the length to be 0. Don't validate for OEM types.
    if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
        operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
        operand.lifetime == Operand::LifeTime::POINTER) {
        if (!isExtension(operand.type) && operand.type != OperandType::OEM &&
            operand.type != OperandType::TENSOR_OEM_BYTE) {
            // getNonExtensionSize was already checked for overflow by validateOperandDimensions.
            const auto expectedLength = getNonExtensionSize(operand).value();
            NN_VALIDATE_EQ(operand.location.length, expectedLength)
                    << "For operand " << operand.type << " expected a size of " << expectedLength
                    << " but got " << operand.location.length;
        }
    }

    return version;
}
| |
| Result<std::vector<Version>> validateOperands( |
| const std::vector<Operand>& operands, size_t operandValuesSize, |
| const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs, |
| std::vector<std::optional<Version>>* subgraphVersionCache) { |
| std::vector<Version> versions; |
| versions.reserve(operands.size()); |
| for (size_t i = 0; i < operands.size(); ++i) { |
| auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphs, |
| subgraphVersionCache); |
| if (!result.has_value()) { |
| return error() << std::move(result).error() << " for operand " << i; |
| } |
| versions.push_back(result.value()); |
| } |
| return versions; |
| } |
| |
// Forward declaration: validates one operation together with the version requirements of the
// operands it touches; defined later in the file (outside this chunk).
Result<Version> validateOperationIncludingOperandVersions(
        const Operation& operation, const std::vector<Operand>& operands,
        const std::vector<Version>& operandVersions, const std::vector<Model::Subgraph>& subgraphs);
| |
| Result<Version> validateOperations(const std::vector<Operation>& operations, |
| const std::vector<Operand>& operands, |
| const std::vector<Version>& operandVersions, |
| const std::vector<Model::Subgraph>& subgraphs) { |
| auto version = Version::ANDROID_OC_MR1; |
| for (size_t i = 0; i < operations.size(); ++i) { |
| auto result = validateOperationIncludingOperandVersions(operations[i], operands, |
| operandVersions, subgraphs); |
| if (!result.has_value()) { |
| return error() << std::move(result).error() << " for operation " << i; |
| } |
| version = combineVersions(version, result.value()); |
| } |
| return version; |
| } |
| |
// A native Handle is valid iff every file descriptor it owns passes unique_fd::ok().
Result<Version> validateHandle(const Handle& handle) {
    NN_VALIDATE(std::all_of(handle.fds.begin(), handle.fds.end(),
                            [](const base::unique_fd& fd) { return fd.ok(); }));
    return Version::ANDROID_OC_MR1;
}
| |
// A SharedHandle must be non-null and its underlying Handle must be valid.
Result<Version> validateSharedHandle(const SharedHandle& handle) {
    NN_VALIDATE(handle != nullptr);
    return validateHandle(*handle);
}
| |
// Ashmem-backed memory needs a valid fd and a non-zero size.
Result<Version> validateMemory(const Memory::Ashmem& memory) {
    NN_VALIDATE(memory.fd.ok());
    NN_VALIDATE_NE(memory.size, 0u);
    return Version::ANDROID_OC_MR1;
}
| |
// File-descriptor-backed memory needs a valid fd, a non-zero size, and a `prot` containing only
// read/write protection bits.
Result<Version> validateMemory(const Memory::Fd& memory) {
    NN_VALIDATE(memory.fd.ok());
    NN_VALIDATE_NE(memory.size, 0u);

    // `prot` is allowed to be either PROT_NONE (which has a value of 0) or the bitwise OR of either
    // PROT_READ or PROT_WRITE. If any other bits are set, the `prot` field is invalid.
    constexpr int kAllowedBits = PROT_READ | PROT_WRITE;
    NN_VALIDATE_EQ(memory.prot & ~kAllowedBits, 0);

    return Version::ANDROID_OC_MR1;
}
| |
// Validates an AHardwareBuffer-backed memory region. AHardwareBuffer support was introduced in
// Android Q, hence the higher minimum version.
Result<Version> validateMemory(const Memory::HardwareBuffer& memory) {
    NN_VALIDATE(memory.handle.get() != nullptr);
    return Version::ANDROID_Q;
}
| |
// Validates a memory region of unknown type by validating its underlying handle. The handle's
// version result is intentionally discarded; Unknown memory itself requires Android Q.
Result<Version> validateMemory(const Memory::Unknown& memory) {
    NN_TRY(validateHandle(memory.handle));
    return Version::ANDROID_Q;
}
| |
// Validates a shared memory object: it must be non-null, and its concrete representation is
// dispatched to the matching validateMemory overload via std::visit.
Result<Version> validateSharedMemory(const SharedMemory& memory) {
    NN_VALIDATE(memory != nullptr);
    return std::visit([](const auto& x) { return validateMemory(x); }, memory->handle);
}
| |
| Result<void> validateModelSubgraphInputOutputs(const std::vector<uint32_t>& indexes, |
| const std::vector<Operand>& operands, |
| Operand::LifeTime lifetime) { |
| const size_t operandCount = operands.size(); |
| for (uint32_t i : indexes) { |
| NN_VALIDATE_LT(i, operandCount) |
| << "Model " << lifetime << " input or output index out of range: " << i << "/" |
| << operandCount; |
| const Operand& operand = operands[i]; |
| NN_VALIDATE_EQ(operand.lifetime, lifetime) |
| << "Model " << lifetime << " operand " << i << " has lifetime of " |
| << operand.lifetime << " instead of the expected " << lifetime; |
| } |
| |
| std::vector<uint32_t> sortedIndexes = indexes; |
| std::sort(sortedIndexes.begin(), sortedIndexes.end()); |
| const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end()); |
| NN_VALIDATE(iter == sortedIndexes.end()) |
| << "Model input or output occurs multiple times: " << *iter; |
| |
| for (size_t i = 0; i < operands.size(); ++i) { |
| if (operands[i].lifetime == lifetime) { |
| const auto containsIndex = [&sortedIndexes](size_t index) { |
| return binary_search(sortedIndexes.begin(), sortedIndexes.end(), index); |
| }; |
| NN_VALIDATE(containsIndex(i)) |
| << "Operand " << i << " marked as " << lifetime |
| << " but is not included in Model input or output indexes"; |
| } |
| } |
| |
| return {}; |
| } |
| |
// Validates that the subgraph's operations are listed in a valid execution (topological) order:
// each operation only reads operands whose values are already known, and each TEMPORARY_VARIABLE
// or SUBGRAPH_OUTPUT operand is written exactly once. Graph cycles surface here as
// read-before-write failures.
Result<void> validateExecutionOrder(const Model::Subgraph& subgraph) {
    // Either the operand has a known value before model execution begins, or we've seen a writer
    // for this operand while walking operands in execution order. Initialize to known operands.
    std::vector<bool> operandValueKnown;
    operandValueKnown.reserve(subgraph.operands.size());
    std::transform(subgraph.operands.begin(), subgraph.operands.end(),
                   std::back_inserter(operandValueKnown), [](const Operand& operand) {
                       return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE &&
                              operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT;
                   });

    // Validate that operations are sorted into execution order.
    //
    // If there is a cycle in the graph, the operations will not
    // appear to be sorted into execution order: Some operation will
    // have an input for which operandValueKnown[] is false.
    for (size_t i = 0; i < subgraph.operations.size(); ++i) {
        const auto& operation = subgraph.operations[i];

        for (size_t j = 0; j < operation.inputs.size(); ++j) {
            const uint32_t k = operation.inputs[j];
            NN_VALIDATE(operandValueKnown[k]) << "Operation " << i << " input " << j << " (operand "
                                              << k << ") is read before it is written";
        }

        for (size_t j = 0; j < operation.outputs.size(); ++j) {
            const uint32_t k = operation.outputs[j];
            // Assuming validateOperations() has not returned an error, we know that this output is
            // TEMPORARY_VARIABLE or SUBGRAPH_OUTPUT, and so the only way operandValueKnown[k] can
            // be true is if we've already seen a writer for this operand.
            NN_VALIDATE(!operandValueKnown[k]) << "Operation " << i << " output " << j
                                              << " (operand " << k << ") has already been written";
            operandValueKnown[k] = true;
        }
    }

    // Verify all operands end up with a known value, i.e. every TEMPORARY_VARIABLE and
    // SUBGRAPH_OUTPUT operand was written by some operation.
    for (size_t i = 0; i < subgraph.operands.size(); ++i) {
        NN_VALIDATE(operandValueKnown[i]) << "Operand " << i << " is never written";
    }

    // TODO(b/77871786): verify that every operation has at least one output operand that is read?

    return {};
}
| |
// Validate a subgraph, ensuring all subgraphs it depends on are also validated.
//
// `referencedIndex` is empty if the subgraph being validated is the main subgraph, otherwise it is
// the index of the referenced subgraph being validated.
//
// referenced[i] and (*subgraphVersionCache)[i] correspond to the same subgraph, and therefore
// `referenced` and `subgraphVersionCache` must have the same length.
//
// Returns the minimum NNAPI version able to express this subgraph, derived from its operands,
// its operations, and (transitively, through SUBGRAPH operands) any referenced subgraphs.
Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
                                      std::optional<size_t> referencedIndex,
                                      size_t operandValuesSize,
                                      const std::vector<size_t>& poolSizes,
                                      const std::vector<Model::Subgraph>& referenced,
                                      std::vector<std::optional<Version>>* subgraphVersionCache) {
    CHECK(subgraphVersionCache != nullptr);
    CHECK_EQ(referenced.size(), subgraphVersionCache->size());

    // Quickly return if the current subgraph has already been checked for its version.
    if (referencedIndex.has_value()) {
        if (auto version = subgraphVersionCache->at(*referencedIndex)) {
            return *version;
        }
    }

    NN_VALIDATE(!subgraph.operands.empty());
    NN_VALIDATE(!subgraph.operations.empty());
    // TODO(b/173780642): Clarify whether subgraphs with no inputs or outputs are valid.
    // NN_VALIDATE(!subgraph.inputIndexes.empty());
    // NN_VALIDATE(!subgraph.outputIndexes.empty());

    // Validating the operands may recurse into referenced subgraphs (for SUBGRAPH operands),
    // populating `subgraphVersionCache` along the way.
    const auto operandVersions = NN_TRY(validateOperands(
            subgraph.operands, operandValuesSize, poolSizes, referenced, subgraphVersionCache));
    const auto operationsVersion = NN_TRY(validateOperations(subgraph.operations, subgraph.operands,
                                                             operandVersions, referenced));

    // Accumulate the versions from all operands and operations.
    const auto version = std::accumulate(operandVersions.begin(), operandVersions.end(),
                                         operationsVersion, combineVersions);

    NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands,
                                             Operand::LifeTime::SUBGRAPH_INPUT));
    NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands,
                                             Operand::LifeTime::SUBGRAPH_OUTPUT));

    NN_TRY(validateExecutionOrder(subgraph));

    // Mark the current subgraph as having already been validated so the caller can quickly return
    // if this subgraph is checked again.
    if (referencedIndex.has_value()) {
        subgraphVersionCache->at(*referencedIndex) = version;
    }
    return version;
}
| |
| Result<Version> validateModelExtensionNamesAndPrefixes( |
| const std::vector<Model::ExtensionNameAndPrefix>& extensionNamesAndPrefixes) { |
| for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) { |
| NN_VALIDATE(isValidExtensionName(extensionNameAndPrefix.name)); |
| } |
| |
| std::vector<std::reference_wrapper<const std::string>> names; |
| names.reserve(extensionNamesAndPrefixes.size()); |
| std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(), |
| std::back_inserter(names), |
| [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) { |
| return std::cref(extensionNameAndPrefix.name); |
| }); |
| std::sort(names.begin(), names.end(), std::less<std::string>{}); |
| const auto nameIter = |
| std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{}); |
| NN_VALIDATE(nameIter == names.end()) |
| << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get(); |
| |
| std::vector<uint16_t> types; |
| types.reserve(extensionNamesAndPrefixes.size()); |
| std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(), |
| std::back_inserter(types), |
| [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) { |
| return extensionNameAndPrefix.prefix; |
| }); |
| std::sort(types.begin(), types.end()); |
| const auto typeIter = std::adjacent_find(types.begin(), types.end()); |
| NN_VALIDATE(typeIter == types.end()) |
| << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter; |
| |
| const bool hasExtensions = !extensionNamesAndPrefixes.empty(); |
| return hasExtensions ? Version::ANDROID_Q : Version::ANDROID_OC_MR1; |
| } |
| |
// Makes sure the model does not contain subgraph reference cycles.
//
// This function verifies that referencedSubgraphs[subgraphIndex] and any subgraphs it refences do
// not contain any reference cycles. `path` is used to keep track of which referenced subgraphs have
// already been visited in the current recursive reference path. `verified` is a cache to keep track
// of which referenced subgraphs have already been verified not to form reference cycles.
//
// referencedSubgraphs[i], (*path)[i], and (*verified)[i] all correspond to the same subgraph, and
// therefore `referencedSubgraphs`, `path`, and `verified` must all have the same length.
//
// This is a standard DFS cycle check: `path` marks the nodes on the current recursion stack, and
// revisiting a node already on the stack proves a cycle.
Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs,
                                    uint32_t subgraphIndex, std::vector<bool>* path,
                                    std::vector<bool>* verified) {
    CHECK(path != nullptr);
    CHECK(verified != nullptr);
    CHECK_EQ(referencedSubgraphs.size(), path->size());
    CHECK_EQ(referencedSubgraphs.size(), verified->size());
    const auto& subgraph = referencedSubgraphs.at(subgraphIndex);

    // Quickly return if the current subgraph has already been verified to have no reference cycles.
    if ((*verified)[subgraphIndex]) {
        return {};
    }

    // Add the current subgraph to the path (making sure that it is not already part of the path),
    // and verify that all subgraphs this subgraph references do not contain cycles. The current
    // subgraph is removed from the path only after all subgraphs this subgraph references have been
    // checked.
    NN_VALIDATE((*path)[subgraphIndex] == false) << "Model contains a circular subgraph reference";
    (*path)[subgraphIndex] = true;
    for (const Operand& operand : subgraph.operands) {
        if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
            // A SUBGRAPH operand stores the referenced subgraph's index in location.offset.
            const uint32_t refSubgraphIndex = operand.location.offset;
            NN_TRY(checkNoReferenceCycles(referencedSubgraphs, refSubgraphIndex, path, verified));
        }
    }
    (*path)[subgraphIndex] = false;

    // Mark the current subgraph as having already been verified so the caller can quickly return if
    // this subgraph is checked again.
    (*verified)[subgraphIndex] = true;
    return {};
}
| |
| Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs) { |
| const size_t count = referencedSubgraphs.size(); |
| std::vector<bool> path(count); |
| std::vector<bool> verified(count); |
| for (size_t i = 0; i < count; ++i) { |
| NN_TRY(checkNoReferenceCycles(referencedSubgraphs, i, &path, &verified)); |
| } |
| return {}; |
| } |
| |
// Validates a complete model (pools, extensions, referenced subgraphs, and the main subgraph),
// returning the minimum NNAPI version capable of representing it.
Result<Version> validateModel(const Model& model) {
    auto version = NN_TRY(validateVector(model.pools, validateSharedMemory));
    version = combineVersions(
            version, NN_TRY(validateModelExtensionNamesAndPrefixes(model.extensionNameToPrefix)));

    // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the
    // execution stricter.

    // Referenced models were introduced in Android R.
    const bool hasReferencedModels = !model.referenced.empty();
    const auto referenceModelVersion =
            hasReferencedModels ? Version::ANDROID_R : Version::ANDROID_OC_MR1;
    version = combineVersions(version, referenceModelVersion);

    // Ensure that there are no cycles formed by the subgraphs. This must happen before
    // validateModelSubgraph recurses through subgraph references.
    NN_TRY(checkNoReferenceCycles(model.referenced));

    // Get memory sizes.
    const auto [operandValuesSize, poolSizes] = getMemorySizes(model);

    // Validate referenced subgraphs. The cache avoids re-validating a subgraph referenced from
    // multiple places.
    auto subgraphVersionCache = std::vector<std::optional<Version>>(model.referenced.size());
    for (size_t referencedIndex = 0; referencedIndex < model.referenced.size(); ++referencedIndex) {
        const auto& subgraph = model.referenced[referencedIndex];
        const auto subgraphVersion =
                NN_TRY(validateModelSubgraph(subgraph, referencedIndex, operandValuesSize,
                                             poolSizes, model.referenced, &subgraphVersionCache));
        version = combineVersions(version, subgraphVersion);
    }

    // Validate main subgraph.
    const auto subgraphVersion =
            NN_TRY(validateModelSubgraph(model.main, std::nullopt, operandValuesSize, poolSizes,
                                         model.referenced, &subgraphVersionCache));
    version = combineVersions(version, subgraphVersion);

    return version;
}
| |
| Result<Version> validateBufferDesc(const BufferDesc& bufferDesc) { |
| // An empty BufferDesc is the default value, so it is implicitly valid for all versions. |
| return bufferDesc.dimensions.empty() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R; |
| } |
| |
// Validates a BufferRole: its probability must lie in (0.0, 1.0]. BufferRole is part of the
// memory-domain feature introduced in Android R.
Result<Version> validateBufferRole(const BufferRole& bufferRole) {
    NN_VALIDATE_GT(bufferRole.probability, 0.0f);
    NN_VALIDATE_LE(bufferRole.probability, 1.0f);
    return Version::ANDROID_R;
}
| |
// Validates a single request input/output argument against its lifetime's invariants.
// `memorySizes` holds the size of each request memory pool (0 for pools whose size is unknown,
// e.g. IBuffer-backed pools); `isOutput` selects output-specific checks.
Result<Version> validateRequestArgument(const Request::Argument& requestArgument,
                                        const std::vector<size_t>& memorySizes, bool isOutput) {
    const auto lifetime = requestArgument.lifetime;
    const auto& location = requestArgument.location;
    const auto& dimensions = requestArgument.dimensions;

    switch (lifetime) {
        case Request::Argument::LifeTime::POOL: {
            // Pool-backed argument: no pointer, a valid pool index, and a region that fits
            // within the pool.
            NN_VALIDATE(location.pointer == kNullptrVariant);
            NN_VALIDATE_LT(location.poolIndex, memorySizes.size());
            // Do the addition using uint64_t to avoid potential wrap-around problems.
            const auto lastPosition =
                    static_cast<uint64_t>(location.offset) + location.length + location.padding;
            const auto memorySize = memorySizes[location.poolIndex];
            NN_VALIDATE_LE(lastPosition, memorySize);
            if (memorySize > 0) {
                // Must specify a positive length if the memory pool has a known size.
                NN_VALIDATE_GT(location.length, 0u);
            }
            return Version::ANDROID_OC_MR1;
        }
        case Request::Argument::LifeTime::NO_VALUE:
            // Omitted argument: every location field must be zeroed and no dimensions given.
            NN_VALIDATE(location.pointer == kNullptrVariant);
            NN_VALIDATE_EQ(location.poolIndex, 0u);
            NN_VALIDATE_EQ(location.offset, 0u);
            NN_VALIDATE_EQ(location.length, 0u);
            NN_VALIDATE_EQ(location.padding, 0u);
            NN_VALIDATE(dimensions.empty());
            return Version::ANDROID_OC_MR1;
        case Request::Argument::LifeTime::POINTER: {
            // Pointer-backed argument: non-null pointer, no pool, non-zero length. Outputs must
            // be writable, i.e. hold a mutable void* rather than a const void*.
            const bool isNullptr =
                    std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
            NN_VALIDATE(!isNullptr);
            NN_VALIDATE_EQ(location.poolIndex, 0u);
            NN_VALIDATE_EQ(location.offset, 0u);
            NN_VALIDATE_NE(location.length, 0u);
            if (isOutput) {
                NN_VALIDATE(std::holds_alternative<void*>(location.pointer));
            }
            return Version::ANDROID_OC_MR1;
        }
    }
    NN_VALIDATE_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime;
}
| |
// Validates one request memory pool. Memory-domain tokens and driver-allocated buffers were
// introduced in Android R; plain shared memory is validated by validateSharedMemory.
Result<Version> validateRequestMemoryPool(const Request::MemoryPool& memoryPool) {
    if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
        NN_VALIDATE(std::get<Request::MemoryDomainToken>(memoryPool) != kInvalidMemoryDomainToken);
        return Version::ANDROID_R;
    }
    if (std::holds_alternative<SharedBuffer>(memoryPool)) {
        NN_VALIDATE(std::get<SharedBuffer>(memoryPool) != nullptr);
        return Version::ANDROID_R;
    }
    return validateSharedMemory(std::get<SharedMemory>(memoryPool));
}
| |
| Result<Version> validateRequest(const Request& request) { |
| auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool)); |
| |
| // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0. |
| std::vector<size_t> memorySizes; |
| memorySizes.reserve(request.pools.size()); |
| std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes), |
| [](const Request::MemoryPool& memoryPool) { |
| const auto* memory = std::get_if<SharedMemory>(&memoryPool); |
| return memory != nullptr ? getSize(*memory) : 0; |
| }); |
| |
| for (size_t i = 0; i < request.inputs.size(); ++i) { |
| const auto& input = request.inputs[i]; |
| auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false); |
| if (!result.has_value()) { |
| return error() << std::move(result).error() << " for input RequestArgument " << i; |
| } |
| version = combineVersions(version, result.value()); |
| } |
| for (size_t i = 0; i < request.outputs.size(); ++i) { |
| const auto& output = request.outputs[i]; |
| auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true); |
| if (!result.has_value()) { |
| return error() << std::move(result).error() << " for output RequestArgument " << i; |
| } |
| version = combineVersions(version, result.value()); |
| } |
| |
| return version; |
| } |
| |
| Result<Version> validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) { |
| if (optionalTimePoint.has_value()) { |
| NN_VALIDATE_GE(optionalTimePoint->time_since_epoch().count(), 0); |
| } |
| // An omitted time point is the default value, so it is implicitly valid for all versions. |
| return !optionalTimePoint.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R; |
| } |
| |
| Result<Version> validateOptionalTimeoutDuration(const OptionalDuration& optionalTimeoutDuration) { |
| if (optionalTimeoutDuration.has_value()) { |
| NN_VALIDATE_GE(optionalTimeoutDuration->count(), 0); |
| } |
| // An omitted duration is the default value, so it is implicitly valid for all versions. |
| return !optionalTimeoutDuration.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R; |
| } |
| |
| Result<Version> validateCacheToken(const CacheToken& cacheToken) { |
| // A CacheToken of 0 is the default value, so it is implicitly valid for all versions. |
| constexpr auto kDefaultCacheToken = CacheToken{}; |
| return cacheToken == kDefaultCacheToken ? Version::ANDROID_OC_MR1 : Version::ANDROID_Q; |
| } |
| |
| Result<Version> validateSyncFence(const SyncFence& syncFence) { |
| // The absence of a sync fence is implicitly valid for all versions. |
| if (!syncFence.hasFd()) { |
| return Version::ANDROID_OC_MR1; |
| } |
| NN_VALIDATE_GE(syncFence.getFd(), 0); |
| return Version::ANDROID_R; |
| } |
| |
// Cross-validates the request's input or output arguments against the corresponding model
// operands: argument count, tensor ranks, per-dimension sizes, and (for non-extension types) the
// declared byte length. Returns the minimum required version; unspecified output dimensions raise
// it to Android Q.
Result<Version> validateRequestArgumentsForModel(
        const std::vector<Request::Argument>& requestArguments,
        const std::vector<uint32_t>& operandIndexes, const std::vector<Operand>& operands,
        bool isOutput, bool allowUnspecifiedOutput) {
    auto version = Version::ANDROID_OC_MR1;
    // The request should specify as many arguments as were described in the model.
    const std::string_view type = isOutput ? "output" : "input";
    const size_t requestArgumentCount = requestArguments.size();
    NN_VALIDATE_EQ(requestArgumentCount, operandIndexes.size())
            << "Request specifies " << requestArgumentCount << " " << type << "s but the model has "
            << operandIndexes.size();
    for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
         requestArgumentIndex++) {
        const Request::Argument& requestArgument = requestArguments[requestArgumentIndex];
        // Get the operand index for this argument. We extract it from the list
        // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
        // We assume in this function that the model has been validated already.
        const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
        const Operand& operand = operands[operandIndex];
        if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) {
            const bool isExtensionType = isExtension(operand.type);
            // If the argument specified a dimension, validate it.
            uint32_t modelRank = operand.dimensions.size();
            uint32_t requestRank = requestArgument.dimensions.size();
            if (requestRank == 0) {
                // The request leaves the rank to the model.
                // NOTE: validateRequestArguments cannot validate unknown tensor rank with
                // extension operand type.
                if (!isExtensionType && !isNonExtensionScalar(operand.type)) {
                    // NOTE(review): modelRank is unsigned, so `<= 0` is effectively `== 0`
                    // (rank unknown in the model as well).
                    if (modelRank <= 0) {
                        NN_VALIDATE(isOutput)
                                << "Model has unknown input rank but the request does not "
                                   "specify the rank.";
                        NN_VALIDATE(allowUnspecifiedOutput)
                                << "Model has unknown output rank and request does not specify it.";
                        // Unspecified output dimensions introduced in Android Q.
                        version = combineVersions(version, Version::ANDROID_Q);
                    }
                }
                // Validate that all the dimensions are specified in the model.
                for (size_t i = 0; i < modelRank; i++) {
                    if (operand.dimensions[i] == 0) {
                        NN_VALIDATE(isOutput && allowUnspecifiedOutput)
                                << "Model has dimension " << i
                                << " set to 0 but the request does not specify the dimension.";
                        // Unspecified output dimensions introduced in Android Q.
                        version = combineVersions(version, Version::ANDROID_Q);
                    }
                }
            } else {
                // The request supplies its own rank; it must agree with the model's (when known),
                // and each specified dimension must agree as well.
                NN_VALIDATE(modelRank == 0 || requestRank == modelRank)
                        << "Request " << type << " " << requestArgumentIndex
                        << " has number of dimensions (" << requestRank
                        << ") different than the model's (" << modelRank << ")";
                for (size_t i = 0; i < requestRank; i++) {
                    NN_VALIDATE(modelRank == 0 || operand.dimensions[i] == 0 ||
                                requestArgument.dimensions[i] == operand.dimensions[i])
                            << "Request " << type << " " << requestArgumentIndex
                            << " has dimension " << i << " of " << requestArgument.dimensions[i]
                            << " different than the model's " << operand.dimensions[i];
                    if (requestArgument.dimensions[i] == 0) {
                        NN_VALIDATE(isOutput && allowUnspecifiedOutput)
                                << "Request " << type << " " << requestArgumentIndex
                                << " has dimension " << i << " of zero";
                        // Unspecified output dimensions introduced in Android Q.
                        version = combineVersions(version, Version::ANDROID_Q);
                    }
                }
            }
            // NOTE: validateRequestArguments cannot validate DataLocation::length
            // with extension operand type.
            if (!isExtensionType && requestArgument.location.length != 0) {
                const auto dimensions =
                        NN_TRY(combineDimensions(operand.dimensions, requestArgument.dimensions));
                const size_t expectedLength = getNonExtensionSize(operand.type, dimensions).value();
                if (expectedLength != 0) {
                    NN_VALIDATE_EQ(requestArgument.location.length, expectedLength)
                            << "Request " << type << " " << requestArgumentIndex
                            << " expected a size of " << expectedLength << " but got "
                            << requestArgument.location.length;
                }
            }
        }
    }
    return version;
}
| |
// Validates a request in the context of a model: the request itself, the model itself, and the
// correspondence between the request's arguments and the model's main-subgraph inputs/outputs.
// Returns the combined minimum required version.
Result<Version> validateRequestForModelImpl(const Request& request, const Model& model,
                                            bool allowUnspecifiedOutput) {
    auto version = NN_TRY(validateRequest(request));
    version = combineVersions(version, NN_TRY(validateModel(model)));
    // Inputs always permit unspecified dimensions in the model as long as the request fills them
    // in; outputs defer to the caller via `allowUnspecifiedOutput`.
    version = combineVersions(version,
                              NN_TRY(validateRequestArgumentsForModel(
                                      request.inputs, model.main.inputIndexes, model.main.operands,
                                      /*isOutput=*/false, /*allowUnspecifiedOutput=*/true)));
    version = combineVersions(
            version, NN_TRY(validateRequestArgumentsForModel(
                             request.outputs, model.main.outputIndexes, model.main.operands,
                             /*isOutput=*/true, allowUnspecifiedOutput)));
    return version;
}
| |
// Validates a memory-domain buffer description against the prepared models that will use it.
// Each input/output role is checked against the corresponding model's main-subgraph operand, all
// roles must agree on operand type/scale/zeroPoint (and extraParams for non-extension types), and
// the buffer's dimensions are combined with each role's operand dimensions.
//
// On success, optionally returns the set of (model, ioType, index) roles via `preparedModelRoles`
// and the merged operand via `combinedOperand`. Memory domains require Android R.
Result<Version> validateMemoryDescImpl(
        const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
        const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
        const std::function<const Model*(const SharedPreparedModel&)>& getModel,
        std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
    NN_VALIDATE(!preparedModels.empty());
    NN_VALIDATE(!inputRoles.empty() || !outputRoles.empty());

    std::set<PreparedModelRole> roles;
    std::vector<nn::Operand> operands;
    operands.reserve(inputRoles.size() + outputRoles.size());
    for (const auto& role : inputRoles) {
        NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_VALIDATE(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_VALIDATE(model != nullptr);
        const auto& inputIndexes = model->main.inputIndexes;
        NN_VALIDATE_LT(role.ioIndex, inputIndexes.size());
        NN_VALIDATE_GT(role.probability, 0.0f);
        NN_VALIDATE_LE(role.probability, 1.0f);
        // emplace fails (success == false) if this (model, INPUT, index) role is a duplicate.
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
        NN_VALIDATE(success);
        operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
    }
    for (const auto& role : outputRoles) {
        NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_VALIDATE(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_VALIDATE(model != nullptr);
        const auto& outputIndexes = model->main.outputIndexes;
        NN_VALIDATE_LT(role.ioIndex, outputIndexes.size());
        NN_VALIDATE_GT(role.probability, 0.0f);
        NN_VALIDATE_LE(role.probability, 1.0f);
        // emplace fails (success == false) if this (model, OUTPUT, index) role is a duplicate.
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
        NN_VALIDATE(success);
        operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
    }

    // Non-empty because at least one of inputRoles/outputRoles is non-empty (validated above).
    CHECK(!operands.empty());
    const auto opType = operands.front().type;

    // All roles must agree on the operand's metadata; dimensions are merged one role at a time.
    Dimensions dimensions = desc.dimensions;
    for (const auto& operand : operands) {
        NN_VALIDATE_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type;
        NN_VALIDATE_EQ(operand.scale, operands.front().scale);
        NN_VALIDATE_EQ(operand.zeroPoint, operands.front().zeroPoint);
        // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
        if (!isExtension(opType)) {
            NN_VALIDATE_EQ(operand.extraParams, operands.front().extraParams)
                    << operand.extraParams << " vs " << operands.front().extraParams;
        }
        dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions));
    }

    // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
    if (!isExtension(opType)) {
        NN_VALIDATE(!isNonExtensionScalar(opType) || dimensions.empty())
                << "invalid dimensions with scalar operand type.";
    }

    if (preparedModelRoles != nullptr) {
        *preparedModelRoles = std::move(roles);
    }
    if (combinedOperand != nullptr) {
        *combinedOperand = operands.front();
        combinedOperand->dimensions = dimensions;
    }
    return Version::ANDROID_R;
}
| |
// Adapter exposing an operation's operands (resolved through its input/output index lists) to the
// per-operation validators via the IOperationValidationContext interface. Holds references only;
// the caller must keep the index and operand vectors alive for the lifetime of this object.
class OperationValidationContext : public IOperationValidationContext {
    DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);

   public:
    OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes,
                               const std::vector<uint32_t>& outputIndexes,
                               const std::vector<Operand>& operands)
        : operationName(operationName),
          inputIndexes(inputIndexes),
          outputIndexes(outputIndexes),
          operands(operands) {}

    const char* getOperationName() const override;

    uint32_t getNumInputs() const override;
    OperandType getInputType(uint32_t index) const override;
    Shape getInputShape(uint32_t index) const override;
    const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override;

    uint32_t getNumOutputs() const override;
    OperandType getOutputType(uint32_t index) const override;
    Shape getOutputShape(uint32_t index) const override;

   private:
    // Resolve the index-th input/output to its Operand (bounds-checked via .at()).
    const Operand* getInputOperand(uint32_t index) const;
    const Operand* getOutputOperand(uint32_t index) const;

    const char* operationName;
    const std::vector<uint32_t>& inputIndexes;
    const std::vector<uint32_t>& outputIndexes;
    const std::vector<Operand>& operands;
};
| |
// Returns the operation name supplied at construction (used in validation error messages).
const char* OperationValidationContext::getOperationName() const {
    return operationName;
}
| |
// Resolves the index-th input to its Operand; .at() aborts on out-of-range indexes.
const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
    return &operands.at(inputIndexes.at(index));
}
| |
// Resolves the index-th output to its Operand; .at() aborts on out-of-range indexes.
const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
    return &operands.at(outputIndexes.at(index));
}
| |
// Returns the number of inputs, narrowed to uint32_t after checking it fits.
uint32_t OperationValidationContext::getNumInputs() const {
    auto count = inputIndexes.size();
    CHECK_LE(count, std::numeric_limits<uint32_t>::max());
    return static_cast<uint32_t>(count);
}
| |
// Returns the number of outputs, narrowed to uint32_t after checking it fits.
uint32_t OperationValidationContext::getNumOutputs() const {
    auto count = outputIndexes.size();
    CHECK_LE(count, std::numeric_limits<uint32_t>::max());
    return static_cast<uint32_t>(count);
}
| |
// Returns the operand type of the index-th input.
OperandType OperationValidationContext::getInputType(uint32_t index) const {
    return getInputOperand(index)->type;
}
| |
// Builds a Shape (type, dimensions, quantization parameters) for the index-th input.
Shape OperationValidationContext::getInputShape(uint32_t index) const {
    const Operand* operand = getInputOperand(index);
    return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
            operand->extraParams};
}
| |
// Returns the extra parameters (e.g. per-channel quantization) of the index-th input.
const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const {
    return getInputOperand(index)->extraParams;
}
| |
// Returns the operand type of the index-th output.
OperandType OperationValidationContext::getOutputType(uint32_t index) const {
    return getOutputOperand(index)->type;
}
| |
// Builds a Shape (type, dimensions, quantization parameters) for the index-th output.
Shape OperationValidationContext::getOutputShape(uint32_t index) const {
    const Operand* operand = getOutputOperand(index);
    return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
            operand->extraParams};
}
| |
| // TODO(b/169345292): reduce the duplicate validation here |
| |
// Validates per-channel symmetric quantization parameters: only allowed for
// TENSOR_QUANT8_SYMM_PER_CHANNEL operands, with a valid channel dimension, a fully-specified
// channel extent, exactly one positive scale per channel. `tag` is prepended to error messages.
Result<void> validateOperandSymmPerChannelQuantParamsImpl(
        const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
        const char* tag) {
    if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        NN_VALIDATE_FAIL();
    }

    NN_VALIDATE_LT(channelQuant.channelDim, operand.dimensions.size()) << tag;
    NN_VALIDATE(!channelQuant.scales.empty()) << tag;
    NN_VALIDATE_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag;
    NN_VALIDATE_NE(operand.dimensions[channelQuant.channelDim], 0u)
            << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
    for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) {
        NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
    }
    return {};
}
| |
// Scalar operand types must not carry any dimensions.
Result<void> validateScalarDimensions(const Operand& type, const char* tag) {
    NN_VALIDATE(type.dimensions.empty()) << tag << " invalid dimensions for scalar type";
    return {};
}
| |
// TENSOR_QUANT8_ASYMM: zeroPoint in [0, 255] (unsigned 8-bit range), positive scale.
Result<void> validateQuant8AsymmParams(const Operand& type, const char* tag) {
    NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 255)
            << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}
| |
// TENSOR_QUANT8_ASYMM_SIGNED: zeroPoint in [-128, 127] (signed 8-bit range), positive scale.
Result<void> validateQuant8AsymmSignedParams(const Operand& type, const char* tag) {
    NN_VALIDATE(-128 <= type.zeroPoint && type.zeroPoint <= 127)
            << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}
| |
// TENSOR_QUANT8_SYMM: symmetric, so zeroPoint must be 0; scale must be positive.
Result<void> validateQuant8SymmParams(const Operand& type, const char* tag) {
    NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}
| |
// TENSOR_QUANT16_ASYMM: zeroPoint in [0, 65535] (unsigned 16-bit range), positive scale.
Result<void> validateQuant16AsymmParams(const Operand& type, const char* tag) {
    NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 65535)
            << tag << " invalid zeroPoint: " << type.zeroPoint;
    NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}
| |
// Symmetric quantized types: zeroPoint must be 0; scale must be positive.
Result<void> validateQuantSymmParams(const Operand& type, const char* tag) {
    NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
    NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
    return {};
}
| |
// Validates that an operand carries no quantization parameters: both zeroPoint
// and scale must be left at their default values (0 and 0.0f).
Result<void> validateNoQuantParams(const Operand& type, const char* tag) {
    NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
    NN_VALIDATE_EQ(type.scale, 0.0f) << tag << " scale is not zero";
    return {};
}
| |
| Result<void> validateTensorDimensions( |
| const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo, |
| const char* tag, bool allowPartial) { |
| if (!allowPartial) { |
| NN_VALIDATE(!type.dimensions.empty()) << tag << " invalid operand dimensions"; |
| } |
| uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize |
| : getNonExtensionSize(type.type); |
| constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max(); |
| for (size_t i = 0; i < type.dimensions.size(); i++) { |
| if (!allowPartial) { |
| NN_VALIDATE_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions"; |
| } |
| if (type.dimensions[i] != 0) { |
| size *= type.dimensions[i]; |
| NN_VALIDATE_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize; |
| } |
| } |
| return {}; |
| } |
| |
| Result<void> validateOperandTypeImpl( |
| const Operand& type, |
| const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag, |
| bool allowPartial) { |
| if (isExtension(type.type)) { |
| NN_VALIDATE(extensionOperandTypeInfo != nullptr); |
| if (extensionOperandTypeInfo->isTensor) { |
| NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial)); |
| } else { |
| NN_TRY(validateScalarDimensions(type, tag)); |
| } |
| return validateNoQuantParams(type, tag); |
| } |
| |
| NN_VALIDATE(extensionOperandTypeInfo == nullptr); |
| NN_TRY(validateOperandType(type.type)); |
| |
| if (isNonExtensionScalar(type.type)) { |
| NN_TRY(validateScalarDimensions(type, tag)); |
| if (type.type != OperandType::OEM) { // Historically, we have allowed OEM types |
| // to use quantization parameters. |
| NN_TRY(validateNoQuantParams(type, tag)); |
| } |
| } else { |
| NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial)); |
| if (type.type == OperandType::TENSOR_QUANT8_ASYMM) { |
| NN_TRY(validateQuant8AsymmParams(type, tag)); |
| } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| NN_TRY(validateQuant8AsymmSignedParams(type, tag)); |
| } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) { |
| NN_TRY(validateQuant8SymmParams(type, tag)); |
| } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) { |
| NN_TRY(validateQuant16AsymmParams(type, tag)); |
| } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) { |
| NN_TRY(validateQuantSymmParams(type, tag)); |
| } else if (type.type == OperandType::TENSOR_INT32 || |
| type.type == OperandType::TENSOR_OEM_BYTE) { |
| // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters. |
| // Historically, we have allowed OEM types to use quantization parameters. |
| } else { |
| NN_TRY(validateNoQuantParams(type, tag)); |
| } |
| } |
| |
| return {}; |
| } |
| |
| Result<void> validateOperandListImpl(const std::vector<uint32_t>& list, size_t operandCount, |
| const char* tag) { |
| for (size_t i = 0; i < list.size(); i++) { |
| NN_VALIDATE_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = " |
| << list[i] << ", operandCount " << operandCount; |
| } |
| return {}; |
| } |
| |
| Result<void> validateOperationOperandTypes(const std::vector<Operand>& operands, |
| const std::vector<uint32_t>& inputIndexes, |
| const std::vector<OperandType>& inExpectedTypes, |
| const std::vector<uint32_t>& outputIndexes, |
| const std::vector<OperandType>& outExpectedInTypes) { |
| NN_VALIDATE_EQ(inputIndexes.size(), inExpectedTypes.size()) |
| << "Wrong operand count: expected " << inputIndexes.size() << " inputs, got " |
| << inputIndexes.size() << " inputs"; |
| NN_VALIDATE_EQ(outputIndexes.size(), outExpectedInTypes.size()) |
| << "Wrong operand count: expected " << outputIndexes.size() << " outputs, got " |
| << outputIndexes.size() << " outputs"; |
| for (size_t i = 0; i < inputIndexes.size(); i++) { |
| NN_VALIDATE_EQ(operands[inputIndexes[i]].type, inExpectedTypes[i]) |
| << "Invalid input tensor type " << operands[inputIndexes[i]].type << " for input " |
| << i << ", expected " << inExpectedTypes[i]; |
| } |
| for (size_t i = 0; i < outputIndexes.size(); i++) { |
| NN_VALIDATE_EQ(operands[outputIndexes[i]].type, outExpectedInTypes[i]) |
| << "Invalid output tensor type " << operands[outputIndexes[i]].type << " for input " |
| << i << ", expected " << outExpectedInTypes[i]; |
| } |
| |
| return {}; |
| } |
| |
// Validates that |modelOperand| is a SUBGRAPH operand whose location.offset is
// a valid index into |subgraphs|. A SUBGRAPH operand stores the referenced
// subgraph's index in its location.offset field.
Result<void> validateSubgraphReference(const std::vector<Model::Subgraph>& subgraphs,
                                       const Operand& modelOperand) {
    NN_VALIDATE_EQ(modelOperand.type, OperandType::SUBGRAPH)
            << "Unexpected operand type: " << modelOperand.type;
    NN_VALIDATE_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference";
    return {};
}
// Returns the subgraph referenced by |modelOperand| (its location.offset holds
// the index). Callers should have validated the reference first; .at() throws
// on an out-of-range index.
const Model::Subgraph& getSubgraph(const std::vector<Model::Subgraph>& subgraphs,
                                   const Operand& modelOperand) {
    return subgraphs.at(modelOperand.location.offset);
}
| uint32_t getInputCount(const std::vector<Model::Subgraph>& subgraphs, const Operand& modelOperand) { |
| return getSubgraph(subgraphs, modelOperand).inputIndexes.size(); |
| } |
| uint32_t getOutputCount(const std::vector<Model::Subgraph>& subgraphs, |
| const Operand& modelOperand) { |
| return getSubgraph(subgraphs, modelOperand).outputIndexes.size(); |
| } |
| const Operand& getInputOperand(const std::vector<Model::Subgraph>& subgraphs, |
| const Operand& modelOperand, uint32_t index) { |
| const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand); |
| return subgraph.operands.at(subgraph.inputIndexes.at(index)); |
| } |
| const Operand& getOutputOperand(const std::vector<Model::Subgraph>& subgraphs, |
| const Operand& modelOperand, uint32_t index) { |
| const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand); |
| return subgraph.operands.at(subgraph.outputIndexes.at(index)); |
| } |
| |
| // Checks if two operands have the same types, ranks (if specified), dimensions |
| // (if specified), scales, zeroPoints, and extraParams. |
| Result<void> compatible(const Operand& a, const Operand& b) { |
| NN_VALIDATE_EQ(a.type, b.type) << a.type << " != " << b.type; |
| if (!a.dimensions.empty() && !b.dimensions.empty()) { |
| NN_VALIDATE_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions"; |
| for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) { |
| if (a.dimensions[i] != 0 && b.dimensions[i] != 0) { |
| NN_VALIDATE_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions"; |
| } |
| } |
| } |
| NN_VALIDATE_EQ(a.scale, b.scale); |
| NN_VALIDATE_EQ(a.zeroPoint, b.zeroPoint); |
| NN_VALIDATE_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams; |
| return {}; |
| } |
| |
// Validates that |operand| may serve as a control flow condition: it must be a
// TENSOR_BOOL8 of rank 1 holding exactly one element.
Result<void> validateConditionOperand(const Operand& operand) {
    NN_VALIDATE_EQ(operand.type, OperandType::TENSOR_BOOL8)
            << "Unexpected condition operand type: " << operand.type;
    NN_VALIDATE_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
    NN_VALIDATE_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
    return {};
}
| |
// Validates an IF operation. inputs[kCondBoolOperand] is the boolean condition
// operand, inputs[kThenModelOperand] and inputs[kElseModelOperand] reference
// the two branch subgraphs, and the remaining inputs (from kFirstInput on) are
// forwarded to the selected branch, whose outputs become the IF outputs.
Result<Version> validateIfOperation(const std::vector<uint32_t>& inputs,
                                    const std::vector<uint32_t>& outputs,
                                    const std::vector<Operand>& operands,
                                    const std::vector<Model::Subgraph>& subgraphs) {
    namespace op = operation_if;
    NN_VALIDATE_GE(inputs.size(), 3u) << "IF must have at least 3 inputs";
    NN_VALIDATE_GE(outputs.size(), 1u) << "IF must have at least 1 output";
    // Checks that a branch operand references a valid subgraph whose inputs and
    // outputs are compatible with the IF operation's own operands.
    auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result<void> {
        auto result = validateSubgraphReference(subgraphs, branchModelOperand);
        if (!result.has_value()) {
            return error() << std::move(result).error()
                           << " -- Operand is not a valid subgraph reference";
        }
        // The IF operation supplies exactly one operand per branch input and
        // consumes exactly one output per branch output.
        const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand);
        const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand);
        NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + branchModelInputCount);
        NN_VALIDATE_EQ(outputs.size(), branchModelOutputCount);
        for (uint32_t i = 0; i < branchModelInputCount; ++i) {
            const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_TRY(compatible(innerOperand, outerOperand));
        }
        for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
            const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i);
            const Operand& outerOperand = operands[outputs[i]];
            NN_TRY(compatible(innerOperand, outerOperand));
        }
        return {};
    };
    auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for IF condition operand";
    }
    result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for IF then model";
    }
    result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for IF else model";
    }
    return Version::ANDROID_R;
}
| |
| Result<Version> validateControlFlowOperandUnknownSize(const Operand& operand) { |
| if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) { |
| // 1.3 HAL (corresponding to Version::ANDROID_R) does not support CF operations with |
| // operands of unknown size. See http://b/132458982#comment63. |
| return Version::CURRENT_RUNTIME; |
| } |
| return Version::ANDROID_R; |
| } |
| |
// Validates a WHILE operation: the condition and body subgraph references, the
// compatibility of the loop-carried operands, and the minimum version the
// operands require.
Result<Version> validateWhileOperation(const std::vector<uint32_t>& inputs,
                                       const std::vector<uint32_t>& outputs,
                                       const std::vector<Operand>& operands,
                                       const std::vector<Model::Subgraph>& subgraphs) {
    // Let the loop have
    // - m >= 1 input-output operands,
    // - k >= 0 state-only operands, and
    // - n >= 0 input-only operands.
    // Then
    // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
    // - the condition model has (m + k + n) inputs and 1 output.
    // - the body model has (m + k + n) inputs and (m + k) outputs.
    namespace op = operation_while;
    NN_VALIDATE_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs";
    NN_VALIDATE_GE(outputs.size(), 1u) << "WHILE must have at least 1 output";
    // Validates the condition subgraph reference: the condition model must
    // consume all (m + k + n) loop-carried operands and produce a single
    // boolean singleton. Returns the minimum version required by its operands.
    auto validateCondOperand = [&](const Operand& condModelOperand) -> Result<Version> {
        Version version = Version::ANDROID_R;
        auto result = validateSubgraphReference(subgraphs, condModelOperand);
        if (!result.has_value()) {
            return error() << std::move(result).error()
                           << " -- Operand is not a valid subgraph reference";
        }
        const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand);
        const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand);
        NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + condModelInputCount);
        NN_VALIDATE_EQ(condModelOutputCount, 1u);
        for (uint32_t i = 0; i < condModelInputCount; ++i) {
            const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_TRY(compatible(innerOperand, outerOperand));
            // Operands of unknown size raise the required version (see
            // validateControlFlowOperandUnknownSize).
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
        }
        NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0)));
        return version;
    };
    // Validates the body subgraph reference: the body consumes the same
    // (m + k + n) operands as the condition model and produces the (m + k)
    // input-output and state-only operands for the next iteration.
    auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result<Version> {
        Version version = Version::ANDROID_R;
        auto result = validateSubgraphReference(subgraphs, bodyModelOperand);
        if (!result.has_value()) {
            return error() << std::move(result).error()
                           << " -- Operand is not a valid subgraph reference";
        }
        const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand);
        const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand);
        NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount);
        NN_VALIDATE_GE(bodyModelOutputCount, outputs.size());
        NN_VALIDATE_GE(bodyModelInputCount, bodyModelOutputCount);
        // Recover m, k, and n from the operand counts (see the comment at the
        // top of this function).
        const uint32_t inputOutputCount = outputs.size();
        const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
        const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
        // Each body input must be compatible with the corresponding WHILE input.
        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
            const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_TRY(compatible(innerOperand, outerOperand));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
        }
        // The first m body outputs must be compatible with the WHILE outputs.
        for (uint32_t i = 0; i < inputOutputCount; ++i) {
            const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
            const Operand& outerOperand = operands[outputs[i]];
            NN_TRY(compatible(innerOperand, outerOperand));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
        }
        // Each loop-carried operand (input-output and state-only) produced by
        // the body must be compatible with the body input it feeds back into.
        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
            const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i);
            const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
            NN_TRY(compatible(inputOperand, outputOperand));
            version = combineVersions(version,
                                      NN_TRY(validateControlFlowOperandUnknownSize(outputOperand)));
        }
        return version;
    };
    auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for WHILE condition model";
    }
    auto version = result.value();
    result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]);
    if (!result.has_value()) {
        return error() << std::move(result).error() << " for WHILE body model";
    }
    version = combineVersions(version, result.value());
    return version;
}
| |
| Result<Version> validateOperationButNotOperandsImpl(const Operation& operation, |
| const std::vector<Operand>& operands, |
| const std::vector<Model::Subgraph>& subgraphs) { |
| const auto opType = operation.type; |
| const auto& inputIndexes = operation.inputs; |
| const auto& outputIndexes = operation.outputs; |
| |
| NN_TRY(validateOperandListImpl(inputIndexes, operands.size(), |
| "ANeuralNetworksModel_addOperation inputs")); |
| NN_TRY(validateOperandListImpl(outputIndexes, operands.size(), |
| "ANeuralNetworksModel_addOperation outputs")); |
| |
| if (isExtension(opType)) { |
| // There is no other validation we can do for an extension operation. |
| return Version::ANDROID_Q; |
| } |
| |
| auto invalidInOutNumberMessage = [opType, &inputIndexes, &outputIndexes](int expIn, |
| int expOut) { |
| std::ostringstream os; |
| os << "Invalid number of input operands (" << inputIndexes.size() << ", expected " << expIn |
| << ") or output operands (" << outputIndexes.size() << ", expected " << expOut |
| << ") for operation " << opType; |
| return os.str(); |
| }; |
| |
| switch (opType) { |
| case OperationType::OEM_OPERATION: { |
| return Version::ANDROID_OC_MR1; |
| } |
| case OperationType::RESHAPE: { |
| NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) |
| << invalidInOutNumberMessage(2, 1); |
| auto inputType = operands[inputIndexes[0]].type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32}; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT32}; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32}; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT16}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32}; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, |
| OperandType::TENSOR_INT32}; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| const auto inputRank = operands[inputIndexes[0]].dimensions.size(); |
| NN_VALIDATE_LE(inputRank, 4u) |
| << "Unsupported input tensor rank for operation " << opType; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::DEPTH_TO_SPACE: { |
| NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && |
| outputIndexes.size() == 1) |
| << "Invalid number of input operands (" << inputIndexes.size() |
| << ", expected 3 or 2) or output operands (" << outputIndexes.size() |
| << ", expected 1) for operation " << opType; |
| auto inputType = operands[inputIndexes[0]].type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT32}; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT16}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| if (inputIndexes.size() == 3) { |
| inExpectedTypes.push_back(OperandType::BOOL); |
| version = combineVersions(version, Version::ANDROID_Q); |
| } else { |
| version = combineVersions(version, Version::ANDROID_OC_MR1); |
| } |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::SPACE_TO_DEPTH: { |
| NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && |
| outputIndexes.size() == 1) |
| << "Invalid number of input operands (" << inputIndexes.size() |
| << ", expected 3 or 2) or output operands (" << outputIndexes.size() |
| << ", expected 1) for operation " << opType; |
| auto inputType = operands[inputIndexes[0]].type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT32}; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT16}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32}; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| if (inputIndexes.size() == 3) { |
| inExpectedTypes.push_back(OperandType::BOOL); |
| version = combineVersions(version, Version::ANDROID_Q); |
| } else { |
| version = combineVersions(version, Version::ANDROID_OC_MR1); |
| } |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::EMBEDDING_LOOKUP: { |
| NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) |
| << invalidInOutNumberMessage(2, 1); |
| auto inputType = operands[inputIndexes[1]].type; |
| NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || |
| inputType == OperandType::TENSOR_FLOAT32 || |
| inputType == OperandType::TENSOR_INT32 || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) |
| << "Unsupported input tensor type for operation " << opType; |
| Version version; |
| std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType}; |
| std::vector<OperandType> outExpectedTypes = {inputType}; |
| if (inputType == OperandType::TENSOR_FLOAT16 || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| } else if (inputType == OperandType::TENSOR_INT32 || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM) { |
| version = Version::ANDROID_Q; |
| } else { |
| version = Version::ANDROID_OC_MR1; |
| } |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::HASHTABLE_LOOKUP: { |
| NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 2) |
| << invalidInOutNumberMessage(3, 2); |
| auto inputType = operands[inputIndexes[2]].type; |
| NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || |
| inputType == OperandType::TENSOR_INT32 || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM) |
| << "Unsupported input tensor type for operation " << opType; |
| std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, |
| OperandType::TENSOR_INT32, inputType}; |
| std::vector<OperandType> outExpectedTypes = {inputType, |
| OperandType::TENSOR_QUANT8_ASYMM}; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return Version::ANDROID_OC_MR1; |
| } |
| case OperationType::LSH_PROJECTION: { |
| NN_VALIDATE(inputIndexes.size() == 4 && outputIndexes.size() == 1) |
| << invalidInOutNumberMessage(4, 1); |
| auto inputType = operands[inputIndexes[1]].type; |
| NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || |
| inputType == OperandType::TENSOR_FLOAT32 || |
| inputType == OperandType::TENSOR_INT32 || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM) |
| << "Unsupported input tensor type for operation " << opType; |
| auto hashType = operands[inputIndexes[0]].type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| if (hashType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT16, |
| inputType, |
| OperandType::TENSOR_FLOAT16, |
| OperandType::INT32, |
| }; |
| } else if (hashType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT32, |
| inputType, |
| OperandType::TENSOR_FLOAT32, |
| OperandType::INT32, |
| }; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported hash tensor type for operation " << opType; |
| } |
| std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32}; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM: { |
| const uint32_t kNumOutputs = 2; |
| const uint32_t kNumOutputsMerged = 1; |
| const uint32_t kNumOutputsWithState = 6; |
| const uint32_t kNumOutputsMergedWithState = 5; |
| NN_VALIDATE(inputIndexes.size() == 61 && |
| (outputIndexes.size() == kNumOutputs || |
| outputIndexes.size() == kNumOutputsMerged || |
| outputIndexes.size() == kNumOutputsWithState || |
| outputIndexes.size() == kNumOutputsMergedWithState)) |
| << "Invalid number of input operands (" << inputIndexes.size() |
| << ", expected 61) or output operands (" << outputIndexes.size() |
| << ", expected 1, 2, 5 or 6) for operation " << opType; |
| |
| std::vector<OperandType> inExpectedTypes; |
| auto inputType = operands[inputIndexes[0]].type; |
| NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || |
| inputType == OperandType::TENSOR_FLOAT16) |
| << "Unsupported input tensor type for operation " << opType; |
| |
| inExpectedTypes = {}; |
| for (int i = 0; i < 48; ++i) { |
| inExpectedTypes.push_back(inputType); |
| } |
| inExpectedTypes.push_back(OperandType::INT32); |
| inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32 |
| ? OperandType::FLOAT32 |
| : OperandType::FLOAT16); |
| inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32 |
| ? OperandType::FLOAT32 |
| : OperandType::FLOAT16); |
| inExpectedTypes.push_back(OperandType::BOOL); |
| inExpectedTypes.push_back(OperandType::BOOL); |
| for (int i = 0; i < 8; ++i) { |
| inExpectedTypes.push_back(inputType); |
| } |
| |
| Version version = Version::ANDROID_Q; |
| if (outputIndexes.size() == kNumOutputsWithState || |
| outputIndexes.size() == kNumOutputsMergedWithState) { |
| version = Version::ANDROID_R; |
| } |
| std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType); |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::LSTM: { |
| NN_VALIDATE((inputIndexes.size() == 23 || inputIndexes.size() == 27) && |
| outputIndexes.size() == 4) |
| << "Invalid number of input operands (" << inputIndexes.size() |
| << ", expected 23 or 27) or output operands (" << outputIndexes.size() |
| << ", expected 4) for operation " << opType; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| auto inputType = operands[inputIndexes[0]].type; |
| NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || |
| inputType == OperandType::TENSOR_FLOAT16) |
| << "Unsupported input tensor type for operation " << opType; |
| |
| Version version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = {inputType, inputType, inputType, inputType, inputType, |
| inputType, inputType, inputType, inputType, inputType, |
| inputType, inputType, inputType, inputType, inputType, |
| inputType, inputType, inputType, inputType, inputType, |
| OperandType::INT32}; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| inExpectedTypes.push_back(OperandType::FLOAT32); |
| inExpectedTypes.push_back(OperandType::FLOAT32); |
| } else { |
| version = Version::ANDROID_Q; |
| inExpectedTypes.push_back(OperandType::FLOAT16); |
| inExpectedTypes.push_back(OperandType::FLOAT16); |
| } |
| |
| outExpectedTypes = {inputType, inputType, inputType, inputType}; |
| if (inputIndexes.size() == 23) { |
| version = combineVersions(version, Version::ANDROID_OC_MR1); |
| } else { |
| version = combineVersions(version, Version::ANDROID_Q); |
| for (int i = 0; i < 4; ++i) { |
| inExpectedTypes.push_back(inputType); |
| } |
| } |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::QUANTIZED_16BIT_LSTM: { |
| NN_VALIDATE(inputIndexes.size() == 15 && outputIndexes.size() == 2) |
| << invalidInOutNumberMessage(15, 2); |
| std::vector<OperandType> inExpectedTypes = { |
| OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, |
| OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, |
| OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, |
| OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, |
| OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32, |
| OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, |
| OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_SYMM, |
| OperandType::TENSOR_QUANT8_ASYMM}; |
| std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM, |
| OperandType::TENSOR_QUANT8_ASYMM}; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return Version::ANDROID_Q; |
| } |
| case OperationType::RANDOM_MULTINOMIAL: { |
| NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) |
| << invalidInOutNumberMessage(3, 1); |
| OperandType inputType = operands[inputIndexes[0]].type; |
| std::vector<OperandType> inExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32 || |
| inputType == OperandType::TENSOR_FLOAT16) { |
| inExpectedTypes = {inputType, OperandType::INT32, OperandType::TENSOR_INT32}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32}; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return Version::ANDROID_Q; |
| } |
| case OperationType::RNN: { |
| NN_VALIDATE(inputIndexes.size() == 6 && outputIndexes.size() == 2) |
| << invalidInOutNumberMessage(6, 2); |
| OperandType inputType = operands[inputIndexes[0]].type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_OC_MR1; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, |
| OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, |
| OperandType::TENSOR_FLOAT32, OperandType::INT32, |
| }; |
| outExpectedTypes = { |
| OperandType::TENSOR_FLOAT32, |
| OperandType::TENSOR_FLOAT32, |
| }; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, |
| OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, |
| OperandType::TENSOR_FLOAT16, OperandType::INT32, |
| }; |
| outExpectedTypes = { |
| OperandType::TENSOR_FLOAT16, |
| OperandType::TENSOR_FLOAT16, |
| }; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::SVDF: { |
| NN_VALIDATE(inputIndexes.size() == 7 && outputIndexes.size() == 2) |
| << invalidInOutNumberMessage(7, 2); |
| Version version; |
| OperandType inputType = operands[inputIndexes[0]].type; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_OC_MR1; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| std::vector<OperandType> inExpectedTypes = { |
| inputType, inputType, inputType, inputType, |
| inputType, OperandType::INT32, OperandType::INT32, |
| }; |
| std::vector<OperandType> outExpectedTypes = {inputType, inputType}; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::BATCH_TO_SPACE_ND: { |
| NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && |
| outputIndexes.size() == 1) |
| << "Invalid number of input operands (" << inputIndexes.size() |
| << ", expected 3 or 2) or output operands (" << outputIndexes.size() |
| << ", expected 1) for operation " << opType; |
| auto inputType = operands[inputIndexes[0]].type; |
| Version version = Version::ANDROID_OC_MR1; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT32, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT32}; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT16, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT16}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { |
| inExpectedTypes = { |
| OperandType::TENSOR_QUANT8_ASYMM, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| inExpectedTypes = { |
| OperandType::TENSOR_QUANT8_ASYMM_SIGNED, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| if (inputIndexes.size() == 3) { |
| inExpectedTypes.push_back(OperandType::BOOL); |
| version = combineVersions(version, Version::ANDROID_Q); |
| } else { |
| version = combineVersions(version, Version::ANDROID_P); |
| } |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::SPACE_TO_BATCH_ND: { |
| NN_VALIDATE((inputIndexes.size() == 4 || inputIndexes.size() == 3) && |
| outputIndexes.size() == 1) |
| << "Invalid number of input operands (" << inputIndexes.size() |
| << ", expected 4 or 3) or output operands (" << outputIndexes.size() |
| << ", expected 1) for operation " << opType; |
| auto inputType = operands[inputIndexes[0]].type; |
| Version version = Version::ANDROID_OC_MR1; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT32, |
| OperandType::TENSOR_INT32, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT32}; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT16, |
| OperandType::TENSOR_INT32, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT16}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { |
| if (operands[inputIndexes[0]].zeroPoint != 0) { |
| version = Version::ANDROID_Q; |
| } |
| inExpectedTypes = { |
| OperandType::TENSOR_QUANT8_ASYMM, |
| OperandType::TENSOR_INT32, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| inExpectedTypes = { |
| OperandType::TENSOR_QUANT8_ASYMM_SIGNED, |
| OperandType::TENSOR_INT32, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| if (inputIndexes.size() == 4) { |
| inExpectedTypes.push_back(OperandType::BOOL); |
| version = combineVersions(version, Version::ANDROID_Q); |
| } else { |
| version = combineVersions(version, Version::ANDROID_P); |
| } |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::PAD: { |
| NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) |
| << invalidInOutNumberMessage(2, 1); |
| auto inputType = operands[inputIndexes[0]].type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_P; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT32, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT32}; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT16, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT16}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| } else { |
| if (operands[inputIndexes[0]].zeroPoint == 0) { |
| version = Version::ANDROID_P; |
| } else { |
| version = Version::ANDROID_Q; |
| } |
| } |
| inExpectedTypes = { |
| inputType, |
| OperandType::TENSOR_INT32, |
| }; |
| outExpectedTypes = {inputType}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| const auto inputRank = operands[inputIndexes[0]].dimensions.size(); |
| NN_VALIDATE_LE(inputRank, 4u) |
| << "Unsupported input tensor rank for operation " << opType; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::PAD_V2: { |
| NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) |
| << invalidInOutNumberMessage(3, 1); |
| auto inputType = operands[inputIndexes[0]].type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
| if (inputType == OperandType::TENSOR_FLOAT32) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT32, |
| OperandType::TENSOR_INT32, |
| OperandType::FLOAT32, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT32}; |
| } else if (inputType == OperandType::TENSOR_FLOAT16) { |
| version = Version::ANDROID_Q; |
| inExpectedTypes = { |
| OperandType::TENSOR_FLOAT16, |
| OperandType::TENSOR_INT32, |
| OperandType::FLOAT16, |
| }; |
| outExpectedTypes = {OperandType::TENSOR_FLOAT16}; |
| } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || |
| inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { |
| version = Version::ANDROID_R; |
| } else { |
| version = Version::ANDROID_Q; |
| } |
| inExpectedTypes = { |
| inputType, |
| OperandType::TENSOR_INT32, |
| OperandType::INT32, |
| }; // TODO(b/116699425): Make it UINT8. |
| outExpectedTypes = {inputType}; |
| } else { |
| NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; |
| } |
| const auto inputRank = operands[inputIndexes[0]].dimensions.size(); |
| NN_VALIDATE_LE(inputRank, 4u) |
| << "Unsupported input tensor rank for operation " << opType; |
| NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, |
| outputIndexes, outExpectedTypes)); |
| return version; |
| } |
| case OperationType::CAST: { |
| NN_VALIDATE(inputIndexes.size() == 1 && outputIndexes.size() == 1) |
| << invalidInOutNumberMessage(1, 1); |
| auto inputOperand = operands[inputIndexes[0]]; |
| auto outputOperand = operands[outputIndexes[0]]; |
| auto inputType = inputOperand.type; |
| auto outputType = outputOperand.type; |
| Version version; |
| std::vector<OperandType> inExpectedTypes; |
| std::vector<OperandType> outExpectedTypes; |
|