| /* |
| * Copyright (C) 2019 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #define LOG_TAG "MetaModel" |
| |
| #include "MetaModel.h" |
| |
| #include <algorithm> |
| #include <map> |
| #include <numeric> |
| #include <set> |
| #include <sstream> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "GraphDump.h" |
| #include "LegacyUtils.h" |
| #include "nnapi/TypeUtils.h" |
| #include "nnapi/Types.h" |
| #include "nnapi/Validation.h" |
| |
| namespace android::nn { |
| |
| namespace { |
| |
| // Add an element to the end of the vector, set it to the specified value, and |
| // return a pair consisting of the index of the new element and a pointer to the |
| // new element. |
| template <class T> |
| std::pair<uint32_t, T*> extend(std::vector<T>* vec, const T& val) { |
| vec->push_back(val); |
| return {vec->size() - 1, &vec->back()}; |
| } |
| |
| // Add an element to the end of the vector and return a pair consisting of the |
| // index of the new element and a pointer to the new element. |
| template <class T> |
| std::pair<uint32_t, T*> extend(std::vector<T>* vec) { |
| return extend(vec, {}); |
| } |
| |
| bool invalid(const Model& model, Version version, bool strictSlicing) { |
| // A model must have at least one operation. However, it's possible that a |
| // slice has no operations (because no operations from the original model |
| // are compliant with the sliced model type). In this case, the sliced |
| // model would be invalid. |
| const bool looksEmpty = (model.main.operations.size() == 0); |
| if (strictSlicing) { |
| CHECK_EQ(looksEmpty, (model.main.operands.size() == 0)); |
| } |
| if (looksEmpty) return true; |
| |
| // A model must have at least one output. However, it's possible for a |
| // model to contain dead operations (i.e., outputs on which no model outputs |
| // are data dependent). A slice might contain only dead operations, and |
| // hence have no model outputs. In this case, the sliced model would be |
| // invalid. |
| if (model.main.outputIndexes.size() == 0) return true; |
| |
| // We shouldn't have to check whether the model is valid. However, it could |
| // be invalid if there is an error in the slicing algorithm. |
| auto maybeVersion = validate(model); |
| if (!maybeVersion.has_value()) { |
| LOG(WARNING) << "Sliced model fails validate(): " << maybeVersion.error(); |
| CHECK(!strictSlicing); |
| return true; |
| } |
| if (maybeVersion.value() > version) { |
| LOG(WARNING) << "Sliced model fails validate(): insufficient version (" |
| << maybeVersion.value() << " vs " << version << ")"; |
| CHECK(!strictSlicing); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| } // anonymous namespace |
| |
| MetaModel::MetaModel(Model model, bool strictSlicing) |
| : mModel(std::move(model)), |
| mModelMinimumSupportedVersion(validate(mModel).value()), |
| mStrictSlicing(strictSlicing) {} |
| |
| MetaModel::ReturnedSlice MetaModel::getSlice(Version version) const { |
| // All slices of versions of at least mModelMinimumSupportedVersion are identical, so do not |
| // create more than one such slice. |
| version = std::min(version, mModelMinimumSupportedVersion); |
| |
| auto& slice = mCachedSlices[version]; |
| if (slice.mState == SliceState::UNINITIALIZED) { |
| slice = makeSlice(version); |
| } |
| if (slice.mState == SliceState::INVALID) { |
| return {}; |
| } |
| return MetaModel::ReturnedSlice(std::make_pair( |
| slice.mModel, Mapper([&slice](uint32_t slicedOperationIndex) { |
| return slice.mSlicedOperationIndexToOrigIndex.at(slicedOperationIndex); |
| }))); |
| } |
| |
| // Utility class for makeSlice(). |
| // |
| // For each output operand of a noncompliant operation that is the input |
| // operand of at least one compliant operation, we will ensure that there is |
| // a sliced model input whose "type" is that of the output operand. This is |
| // a map from operand "type" (in the original model) to model input operand |
| // index (in the sliced model). We only use the subset of the fields that are |
| // relevant (OperandType, dimensions, scale, zeroPoint, extraParams), but |
| // exclude irrelevant fields from the map key (lifetime, location). |
| // |
| // We also use this map for model input operands of the original model that |
| // become input operands of the sliced model. This means that an original |
| // model input operand might be commoned with other original model input |
| // operands and/or with original model temporary operands. |
| class MetaModel::OrigOperandToSlicedInputOperandIndex { |
| public: |
| // `slicedOperands` and `slicedInputIndexes` will be modified as part of |
| // OrigOperandToSlicedInputOperandIndex::getIndex. `slicedVersion`, `operandValuesSize`, and |
| // `poolSizes` are used as a check to ensure that the sliced operand is valid and compliant with |
| // the sliced version. `operandValuesSize` is the size of the operand values in the sliced model |
| // (which is the same as the original model). `poolSizes` is the size of the memories in the |
| // sliced model (which is the same as the original model). |
| OrigOperandToSlicedInputOperandIndex(std::vector<Operand>* slicedOperands, |
| std::vector<uint32_t>* slicedInputIndexes, |
| Version slicedVersion, size_t operandValuesSize, |
| std::vector<size_t> poolSizes) |
| : mSlicedOperands(*slicedOperands), |
| mSlicedInputIndexes(*slicedInputIndexes), |
| kSlicedVersion(slicedVersion), |
| kOperandValuesSize(operandValuesSize), |
| kPoolSizes(std::move(poolSizes)) {} |
| |
| // Given an operand from the original model, return the index of the |
| // corresponding model input operand from the sliced model. Creates a |
| // new operand in the sliced model if necessary. |
| uint32_t getIndex(Operand operand) { |
| CHECK(operand.lifetime == Operand::LifeTime::SUBGRAPH_INPUT || |
| operand.lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT || |
| operand.lifetime == Operand::LifeTime::TEMPORARY_VARIABLE); |
| |
| // Lookup |
| auto it = mMap.find(operand); |
| if (it != mMap.end()) { |
| VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex looked for " |
| << operand << " and found " << it->second << ": " << it->first; |
| return it->second; |
| } |
| |
| // Create |
| operand.lifetime = Operand::LifeTime::SUBGRAPH_INPUT; |
| operand.location = {}; |
| |
| // Note that the sliced model does not contain any referenced subgraphs, so both `subgraphs` |
| // and `subgraphVersionCache` are empty. |
| const std::vector<Model::Subgraph> subgraphs; |
| auto subgraphVersionCache = createSubgraphVersionCache(subgraphs.size()); |
| const auto minimumSupportedOperandVersion = |
| validateOperandAndAnythingItDependsOn(operand, kOperandValuesSize, kPoolSizes, |
| subgraphs, subgraphVersionCache.get()) |
| .value(); |
| CHECK_LE(minimumSupportedOperandVersion, kSlicedVersion); |
| |
| uint32_t slicedOperandIndex = extend(&mSlicedOperands, operand).first; |
| mMap[operand] = slicedOperandIndex; |
| extend(&mSlicedInputIndexes, slicedOperandIndex); |
| VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex created " |
| << slicedOperandIndex << ": " << operand; |
| return slicedOperandIndex; |
| } |
| |
| private: |
| class Compare { |
| public: |
| bool operator()(const Operand& a, const Operand& b) const { |
| if (a.type != b.type) { |
| return a.type < b.type; |
| } |
| if (a.dimensions != b.dimensions) { |
| return a.dimensions < b.dimensions; |
| } |
| if (a.scale != b.scale) { |
| return a.scale < b.scale; |
| } |
| if (a.zeroPoint != b.zeroPoint) { |
| return a.zeroPoint < b.zeroPoint; |
| } |
| return compare(a.extraParams, b.extraParams); |
| } |
| |
| private: |
| static bool compare(const Operand::SymmPerChannelQuantParams& a, |
| const Operand::SymmPerChannelQuantParams& b) { |
| if (a.scales != b.scales) { |
| return a.scales < b.scales; |
| } |
| return a.channelDim < b.channelDim; |
| } |
| static bool compare(const Operand::ExtraParams& a, const Operand::ExtraParams& b) { |
| if (a.index() != b.index()) { |
| return a.index() < b.index(); |
| } |
| if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(a)) { |
| return compare(std::get<Operand::SymmPerChannelQuantParams>(a), |
| std::get<Operand::SymmPerChannelQuantParams>(b)); |
| } |
| if (std::holds_alternative<Operand::ExtensionParams>(a)) { |
| return std::get<Operand::ExtensionParams>(a) < |
| std::get<Operand::ExtensionParams>(b); |
| } |
| if (std::holds_alternative<Operand::NoParams>(a)) { |
| return false; |
| } |
| CHECK(false) << "Unexpected"; |
| return false; |
| } |
| }; |
| std::map<Operand, uint32_t, Compare> mMap; |
| std::vector<Operand>& mSlicedOperands; |
| std::vector<uint32_t>& mSlicedInputIndexes; |
| const Version kSlicedVersion; |
| const size_t kOperandValuesSize; |
| const std::vector<size_t> kPoolSizes; |
| }; |
| |
| void MetaModel::processOperations( |
| Slice* slice, std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex, |
| OrigOperandToSlicedInputOperandIndex* origOperandToSlicedInputOperandIndex, |
| const std::set<uint32_t>& noncompliantOperations, |
| const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const { |
| const auto& origOperands = mModel.main.operands; |
| const auto& origOperations = mModel.main.operations; |
| auto& slicedOperands = slice->mModel.main.operands; |
| auto& slicedOperations = slice->mModel.main.operations; |
| |
| std::vector<uint32_t> origOperandNumberOfConsumers = |
| countNumberOfConsumers(origOperands.size(), origOperations).value(); |
| |
| for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size(); |
| ++origOperationIndex) { |
| const Operation& origOperation = origOperations[origOperationIndex]; |
| |
| if (noncompliantOperations.count(origOperationIndex)) { |
| for (uint32_t output : origOperation.outputs) { |
| if (!inputOperandIndexesOfCompliantOperations.count(output)) { |
| continue; |
| } |
| const uint32_t slicedIndex = |
| origOperandToSlicedInputOperandIndex->getIndex(origOperands[output]); |
| (*origOperandIndexToSlicedIndex)[output] = slicedIndex; |
| VLOG(COMPILATION) |
| << "origOperandIndexToSlicedIndex noncompliant output processing created " |
| << output << " -> " << slicedIndex << ": " << slicedOperands[slicedIndex]; |
| } |
| } else { |
| slice->mSlicedOperationIndexToOrigIndex.push_back(origOperationIndex); |
| Operation& slicedOperation = *extend(&slicedOperations).second; |
| CHECK_EQ(slice->mSlicedOperationIndexToOrigIndex.size(), slicedOperations.size()); |
| |
| slicedOperation.type = origOperation.type; |
| |
| // Model is topologically sorted, so all operation inputs must be |
| // present in origOperandIndexToSlicedIndex, and no operation |
| // outputs may be. |
| |
| // Operation inputs |
| // - Fill in slicedOperation.inputs |
| slicedOperation.inputs.resize(origOperation.inputs.size()); |
| std::transform( |
| origOperation.inputs.begin(), origOperation.inputs.end(), |
| slicedOperation.inputs.begin(), |
| [&origOperandIndexToSlicedIndex, &slicedOperands](uint32_t origOperandIndex) { |
| uint32_t slicedOperandIndex = |
| origOperandIndexToSlicedIndex->at(origOperandIndex); |
| VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant input " |
| "processing created " |
| << origOperandIndex << " -> " << slicedOperandIndex |
| << ": " << slicedOperands[slicedOperandIndex]; |
| return slicedOperandIndex; |
| }); |
| |
| // Operation outputs |
| // - Add new operands to slicedOperands |
| // - Update origOperandIndexToSlicedIndex |
| // - Fill in slicedOperation.outputs |
| // - Record as a model output, if necessary |
| const uint32_t firstOutputSlicedOperandIndex = slicedOperands.size(); |
| slicedOperands.resize(firstOutputSlicedOperandIndex + origOperation.outputs.size()); |
| slicedOperation.outputs.resize(origOperation.outputs.size()); |
| for (uint32_t outputNum = 0; outputNum < slicedOperation.outputs.size(); ++outputNum) { |
| uint32_t origOperandIndex = origOperation.outputs[outputNum]; |
| uint32_t slicedOperandIndex = firstOutputSlicedOperandIndex + outputNum; |
| auto& slicedOperand = slicedOperands[slicedOperandIndex]; |
| const auto& origOperand = origOperands[origOperandIndex]; |
| slicedOperand = origOperand; |
| |
| CHECK_EQ(origOperandIndexToSlicedIndex->count(origOperandIndex), size_t(0)); |
| (*origOperandIndexToSlicedIndex)[origOperandIndex] = slicedOperandIndex; |
| slicedOperation.outputs[outputNum] = slicedOperandIndex; |
| |
| const auto subgraphOutputLifetime = Operand::LifeTime::SUBGRAPH_OUTPUT; |
| if (!inputOperandIndexesOfCompliantOperations.count(origOperandIndex) && |
| origOperandNumberOfConsumers[origOperandIndex] != 0) { |
| // Was consumed only by noncompliant operations; convert to |
| // an output of the sliced model. |
| slicedOperand.lifetime = subgraphOutputLifetime; |
| } |
| |
| VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant output created " |
| << origOperandIndex << " -> " << slicedOperandIndex << ": " |
| << slicedOperand; |
| |
| if (slicedOperand.lifetime == subgraphOutputLifetime) { |
| extend(&slice->mModel.main.outputIndexes, slicedOperandIndex); |
| } |
| } |
| } |
| } |
| } |
| |
| std::set<uint32_t> MetaModel::getNoncompliantOperations(Version version) const { |
| const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel); |
| |
| auto subgraphVersionCache = createSubgraphVersionCache(mModel.referenced.size()); |
| std::set<uint32_t> noncompliantOperations; |
| for (uint32_t i = 0; i < mModel.main.operations.size(); ++i) { |
| const auto& operation = mModel.main.operations[i]; |
| const auto minSupportedVersion = |
| validateOperationAndAnythingItDependsOn( |
| operation, mModel.main.operands, operandValuesSize, poolSizes, |
| mModel.referenced, subgraphVersionCache.get()) |
| .value(); |
| if (minSupportedVersion > version) { |
| noncompliantOperations.insert(i); |
| } |
| } |
| return noncompliantOperations; |
| } |
| |
| MetaModel::Slice MetaModel::makeSlice(Version version) const { |
| Slice slice; |
| |
| // Quickly return if the model is already compliant with `version` |
| if (version >= mModelMinimumSupportedVersion) { |
| slice.mModel = mModel; |
| slice.mSlicedOperationIndexToOrigIndex = |
| std::vector<uint32_t>(mModel.main.operations.size()); |
| std::iota(slice.mSlicedOperationIndexToOrigIndex.begin(), |
| slice.mSlicedOperationIndexToOrigIndex.end(), 0u); |
| slice.mState = SliceState::NORMAL; |
| return slice; |
| } |
| |
| const auto& origOperands = mModel.main.operands; |
| const auto& origOperations = mModel.main.operations; |
| auto& slicedOperands = slice.mModel.main.operands; |
| |
| // Indexes of elements of noncompliant origOperations |
| std::set<uint32_t> noncompliantOperations = getNoncompliantOperations(version); |
| |
| // Check if any compliant operations require a subgraph. |
| bool someCompliantOperationHasASubgraphOperand = false; |
| if (!mModel.referenced.empty()) { |
| for (size_t i = 0; i < mModel.main.operations.size(); ++i) { |
| const auto& operation = mModel.main.operations[i]; |
| if (noncompliantOperations.count(i) > 0) { |
| continue; |
| } |
| const auto isSubgraph = [&origOperands](uint32_t opndIdx) { |
| return origOperands[opndIdx].lifetime == Operand::LifeTime::SUBGRAPH; |
| }; |
| if (std::any_of(operation.inputs.begin(), operation.inputs.end(), isSubgraph)) { |
| someCompliantOperationHasASubgraphOperand = true; |
| break; |
| } |
| } |
| } |
| |
| // TODO(b/175418767): Currently, MetaModel is not equipped to slice referenced subgraphs. If the |
| // original model is not compliant with the specified version and contains referenced subgraphs |
| // needed by the slice, return an invalidated slice. |
| if (someCompliantOperationHasASubgraphOperand) { |
| slice.mState = SliceState::INVALID; |
| return slice; |
| } |
| |
| // Map from an operand index in origOperands to the corresponding operand index in |
| // slicedOperands |
| std::map<uint32_t, uint32_t> origOperandIndexToSlicedIndex; |
| |
| // Collect the operand indexes of every operand that is an input to a |
| // compliant operation. If the operand is a CONSTANT_*, POINTER, or a |
| // NO_VALUE, copy it to the sliced model and update |
| // origOperandIndexToSlicedIndex accordingly. Otherwise, we'll deal with |
| // the operand in the subsequent "Main loop", where we process operation |
| // outputs (intermediates and model outputs). |
| std::set<uint32_t> inputOperandIndexesOfCompliantOperations; |
| for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size(); |
| ++origOperationIndex) { |
| if (noncompliantOperations.count(origOperationIndex)) { |
| continue; |
| } |
| for (uint32_t input : origOperations[origOperationIndex].inputs) { |
| if (inputOperandIndexesOfCompliantOperations.insert(input).second) { |
| const Operand& origOperand = origOperands[input]; |
| switch (origOperand.lifetime) { |
| case Operand::LifeTime::CONSTANT_COPY: |
| case Operand::LifeTime::CONSTANT_REFERENCE: |
| case Operand::LifeTime::POINTER: |
| case Operand::LifeTime::NO_VALUE: { |
| const uint32_t slicedOperandIndex = |
| extend(&slicedOperands, origOperand).first; |
| origOperandIndexToSlicedIndex[input] = slicedOperandIndex; |
| VLOG(COMPILATION) << "origOperandIndexToSlicedIndex initialization created " |
| << input << " -> " << slicedOperandIndex << ": " |
| << slicedOperands[slicedOperandIndex]; |
| break; |
| } |
| default: |
| break; |
| } |
| } |
| } |
| } |
| |
| const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel); |
| |
| OrigOperandToSlicedInputOperandIndex origOperandToSlicedInputOperandIndex( |
| &slicedOperands, &slice.mModel.main.inputIndexes, version, operandValuesSize, |
| poolSizes); |
| |
| // An input of the original model is an input of the sliced model if and |
| // only if it is consumed by at least one compliant operation. Note that in |
| // the sliced model we share all model inputs of the same "type"; and that |
| // we may later add model inputs to the sliced model. |
| for (uint32_t origInputIndex : mModel.main.inputIndexes) { |
| if (inputOperandIndexesOfCompliantOperations.count(origInputIndex)) { |
| const uint32_t slicedIndex = |
| origOperandToSlicedInputOperandIndex.getIndex(origOperands[origInputIndex]); |
| origOperandIndexToSlicedIndex[origInputIndex] = slicedIndex; |
| VLOG(COMPILATION) << "origOperandIndexToSlicedIndex inputIndexes processing created " |
| << origInputIndex << " -> " << slicedIndex << ": " |
| << slicedOperands[slicedIndex]; |
| } |
| } |
| |
| // Main loop: Process each operation of the original model. |
| processOperations(&slice, &origOperandIndexToSlicedIndex, &origOperandToSlicedInputOperandIndex, |
| noncompliantOperations, inputOperandIndexesOfCompliantOperations); |
| |
| // To keep things simple, we copy over these fields as-is. We could instead |
| // opt to regenerate them based on the operands present in the sliced model: |
| // This would be more complex and probably take more computation time, but |
| // it would reduce the size of the sliced model, and hence the time spent |
| // copying it around and potentially passing it across process boundaries. |
| slice.mModel.operandValues = mModel.operandValues; |
| slice.mModel.pools = mModel.pools; |
| |
| if (VLOG_IS_ON(COMPILATION)) { |
| { |
| std::ostringstream fromName; |
| fromName << "Slice: From canonical"; |
| graphDump(fromName.str().c_str(), mModel); |
| } |
| { |
| std::ostringstream toName; |
| toName << "Slice: To " << version; |
| graphDump(toName.str().c_str(), slice.mModel); |
| } |
| } |
| |
| slice.mState = invalid(slice.mModel, version, mStrictSlicing) ? SliceState::INVALID |
| : SliceState::NORMAL; |
| |
| return slice; |
| } |
| |
| } // namespace android::nn |