blob: b6d0c1a3a4ff8510d35a71861f2f30cd03dbd4da [file] [log] [blame]
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "ModelUtils"
#include "ModelUtils.h"
#include <android-base/logging.h>
#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <utility>
#include <vector>
#include "nnapi/TypeUtils.h"
#include "nnapi/Types.h"
#include "nnapi/Validation.h"
namespace android::nn {
namespace {
// Assign consecutive indexes, starting at 0, to the `true` entries of `includes`. The entry
// produced for a `false` position is simply the number of `true` values seen so far; callers are
// expected to ignore it (shown as X below). E.g.:
//   includes = {false, true, true, false, true}
//   returned = {    X,    0,    1,     X,    2}
// Postcondition: returned.size() == includes.size()
std::vector<uint32_t> getMapping(const std::vector<bool>& includes) {
    std::vector<uint32_t> mapping;
    mapping.reserve(includes.size());
    // Running count of `true` values strictly before the current position.
    uint32_t nextIndex = 0u;
    for (const bool included : includes) {
        mapping.push_back(nextIndex);
        if (included) {
            ++nextIndex;
        }
    }
    return mapping;
}
// Replace every value `i` stored in `indexes` with `mapping.at(i)`, in place.
// Precondition: indexes != nullptr
// Note: `at` is bounds-checked, so a value that is not a valid position in `mapping` is reported
// by the standard library rather than read out of bounds.
void remapIndexes(std::vector<uint32_t>* indexes, const std::vector<uint32_t>& mapping) {
    CHECK(indexes != nullptr);
    std::transform(indexes->begin(), indexes->end(), indexes->begin(),
                   [&mapping](uint32_t oldIndex) { return mapping.at(oldIndex); });
}
// Keep the elements of `elements` whose corresponding flag in `elementsToKeep` is `true`, and
// remove all others. The relative order of the kept elements is preserved.
// Precondition: elements != nullptr
// Precondition: elements->size() == elementsToKeep.size()
// `Type` only needs to be move-assignable; it does not need to be default-constructible.
template <typename Type>
void keepSelectedElements(std::vector<Type>* elements, const std::vector<bool>& elementsToKeep) {
    CHECK(elements != nullptr);
    CHECK_EQ(elements->size(), elementsToKeep.size());
    // Compact the kept elements towards the front of the vector.
    size_t elementsCopied = 0;
    for (size_t i = 0; i < elementsToKeep.size(); ++i) {
        if (!elementsToKeep[i]) {
            continue;
        }
        // Skip the move when the element is already in place to avoid a self-move-assignment.
        if (elementsCopied != i) {
            (*elements)[elementsCopied] = std::move((*elements)[i]);
        }
        ++elementsCopied;
    }
    // Drop the moved-from tail. `erase` is used instead of `resize` because `resize` requires
    // `Type` to be default-constructible even when it only ever shrinks the vector.
    using Difference = typename std::vector<Type>::difference_type;
    elements->erase(elements->begin() + static_cast<Difference>(elementsCopied), elements->end());
}
// Find which operands in model.main.operands are read or written by model.main.operations, or are
// listed in model.main.inputIndexes or model.main.outputIndexes.
// Output operands are always treated as used: removeDeadOperands remaps model.main.outputIndexes,
// so dropping an operand that is still listed there would leave that entry pointing at the wrong
// operand (or past the end of the operand list).
// Postcondition: returned.size() == model.main.operands.size()
std::vector<bool> identifyUsedOperands(const Model& model) {
    std::vector<bool> used(model.main.operands.size(), false);
    // `at` is bounds-checked, so an index that does not name an operand is not silently accepted.
    const auto markUsed = [&used](const std::vector<uint32_t>& indexes) {
        for (const uint32_t index : indexes) {
            used.at(index) = true;
        }
    };
    for (const auto& operation : model.main.operations) {
        markUsed(operation.inputs);
        markUsed(operation.outputs);
    }
    markUsed(model.main.inputIndexes);
    markUsed(model.main.outputIndexes);
    CHECK_EQ(used.size(), model.main.operands.size());
    return used;
}
// Forward declaration of the index-based overload. It is needed because the operand-based overload
// and the index-based overload of identifyUsedSubgraphs call each other recursively.
void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
std::vector<bool>* used);
// Helper function that marks in `used` every subgraph reachable from `operands`. An operand refers
// to a subgraph when its lifetime is SUBGRAPH, in which case `location.offset` holds the index of
// that subgraph within `subgraphs`.
// Precondition: used != nullptr
// Precondition: subgraphs.size() == used->size()
void identifyUsedSubgraphs(const std::vector<Operand>& operands,
                           const std::vector<Model::Subgraph>& subgraphs, std::vector<bool>* used) {
    for (const auto& operand : operands) {
        if (operand.lifetime != Operand::LifeTime::SUBGRAPH) {
            continue;
        }
        const auto referencedIndex = operand.location.offset;
        identifyUsedSubgraphs(referencedIndex, subgraphs, used);
    }
}
// Helper function that marks the subgraph at index `current` as used in `used`, and then marks
// every subgraph reachable from it. Because an already-marked subgraph is not visited again,
// `used` doubles as a visited set and each subgraph is processed at most once.
// Precondition: used != nullptr
// Precondition: subgraphs.size() == used->size()
// Precondition: current < subgraphs.size()
void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
                           std::vector<bool>* used) {
    CHECK(used != nullptr);
    CHECK_EQ(subgraphs.size(), used->size());
    CHECK_LT(current, subgraphs.size());
    std::vector<bool>& visited = *used;
    if (!visited[current]) {
        // Mark before recursing so that the recursion never re-enters this subgraph.
        visited[current] = true;
        identifyUsedSubgraphs(subgraphs[current].operands, subgraphs, used);
    }
}
// Compute, for each entry of model.referenced, whether it is reachable (directly or through other
// referenced subgraphs) from the operands of model.main.
// Postcondition: returned.size() == model.referenced.size()
std::vector<bool> identifyUsedSubgraphs(const Model& model) {
    const auto& referenced = model.referenced;
    std::vector<bool> used(referenced.size(), false);
    // Start the traversal from the main subgraph's operands.
    identifyUsedSubgraphs(model.main.operands, referenced, &used);
    CHECK_EQ(used.size(), model.referenced.size());
    return used;
}
// Helper function that marks in `used` every pool referenced by an operand of `subgraph`. Only
// operands whose lifetime is CONSTANT_REFERENCE are considered; their `location.poolIndex` selects
// the entry of `used` to set.
// Precondition: used != nullptr
void identifyUsedPools(const Model::Subgraph& subgraph, std::vector<bool>* used) {
    CHECK(used != nullptr);
    for (const auto& operand : subgraph.operands) {
        if (operand.lifetime != Operand::LifeTime::CONSTANT_REFERENCE) {
            continue;
        }
        const auto poolIndex = operand.location.poolIndex;
        used->at(poolIndex) = true;
    }
}
// Find which pools are used by `model`.
// Postcondition: returned.size() == model.pools.size()
std::vector<bool> identifyUsedPools(const Model& model) {
std::vector<bool> used(model.pools.size(), false);
identifyUsedPools(model.main, &used);
for (const auto& subgraph : model.referenced) {
identifyUsedPools(subgraph, &used);
}
CHECK_EQ(used.size(), model.pools.size());
return used;
}
// Update the DataLocation of `operand` according to the operand's lifetime:
//   - CONSTANT_COPY: the bytes are copied from `oldOperandValues` into `newOperandValues`, and the
//     location is replaced by the one returned from the append.
//   - CONSTANT_REFERENCE: `location.poolIndex` is translated through `remappedPoolIndex`.
//   - SUBGRAPH: `location.offset` (the subgraph index) is translated through
//     `remappedSubgraphIndex`.
//   - Any other lifetime: the location is left untouched.
// Precondition: operand != nullptr
// Precondition: newOperandValues != nullptr
void fixOperandDataLocation(Operand* operand, Model::OperandValues* newOperandValues,
                            const Model::OperandValues& oldOperandValues,
                            const std::vector<uint32_t>& remappedPoolIndex,
                            const std::vector<uint32_t>& remappedSubgraphIndex) {
    CHECK(operand != nullptr);
    CHECK(newOperandValues != nullptr);
    auto& location = operand->location;
    switch (operand->lifetime) {
        // Lifetimes whose location does not refer to operand values, pools, or subgraphs.
        case Operand::LifeTime::TEMPORARY_VARIABLE:
        case Operand::LifeTime::SUBGRAPH_INPUT:
        case Operand::LifeTime::SUBGRAPH_OUTPUT:
        case Operand::LifeTime::NO_VALUE:
        case Operand::LifeTime::POINTER:
            break;
        case Operand::LifeTime::CONSTANT_COPY: {
            // Read the old offset/length before `location` is overwritten by the append result.
            const uint8_t* data = oldOperandValues.data() + location.offset;
            const uint32_t length = location.length;
            location = newOperandValues->append(data, length);
            break;
        }
        case Operand::LifeTime::CONSTANT_REFERENCE: {
            location.poolIndex = remappedPoolIndex.at(location.poolIndex);
            break;
        }
        case Operand::LifeTime::SUBGRAPH: {
            // For SUBGRAPH operands, `offset` stores the index of the referenced subgraph.
            location.offset = remappedSubgraphIndex.at(location.offset);
            break;
        }
    }
}
// Fix all DataLocations in `operands` by either remapping an index or by copying constant data.
// Precondition: operands != nullptr
// Precondition: newOperandValues != nullptr
void fixOperandDataLocations(std::vector<Operand>* operands, Model::OperandValues* newOperandValues,
const Model::OperandValues& oldOperandValues,
const std::vector<uint32_t>& remappedPoolIndex,
const std::vector<uint32_t>& remappedSubgraphIndex) {
for (Operand& operand : (*operands)) {
fixOperandDataLocation(&operand, newOperandValues, oldOperandValues, remappedPoolIndex,
remappedSubgraphIndex);
}
}
// Fix all operands' DataLocations in `model` by either remapping an index or by copying constant
// data. model->operandValues is rebuilt from scratch: it is first emptied, and then refilled with
// the constant data of the operands in model->main and in every subgraph of model->referenced.
// Precondition: model != nullptr
void fixOperandDataLocations(Model* model, const std::vector<uint32_t>& remappedPoolIndex,
                             const std::vector<uint32_t>& remappedSubgraphIndex) {
    // Enforce the documented precondition before `model` is dereferenced.
    CHECK(model != nullptr);
    // Take the old values out of the model, leaving an empty store to append into.
    const auto oldOperandValues = std::exchange(model->operandValues, Model::OperandValues{});
    fixOperandDataLocations(&model->main.operands, &model->operandValues, oldOperandValues,
                            remappedPoolIndex, remappedSubgraphIndex);
    for (auto& subgraph : model->referenced) {
        fixOperandDataLocations(&subgraph.operands, &model->operandValues, oldOperandValues,
                                remappedPoolIndex, remappedSubgraphIndex);
    }
}
// Compute, for each entry of model.extensionNameToPrefix, whether its prefix is the extension
// prefix of the type of at least one operand or operation in the main subgraph or in any
// referenced subgraph.
// Postcondition: returned.size() == model.extensionNameToPrefix.size()
std::vector<bool> identifyUsedExtensions(const Model& model) {
    std::unordered_set<uint16_t> prefixes;
    // Record the extension prefix of an operand's or operation's type, skipping the prefix value 0
    // that the code treats as "standard" (non-extension).
    const auto recordPrefix = [&prefixes](const auto& operandOrOperation) {
        constexpr uint16_t kStandardPrefix = 0u;
        const auto prefix = getExtensionPrefix(static_cast<uint32_t>(operandOrOperation.type));
        if (prefix == kStandardPrefix) {
            return;
        }
        prefixes.insert(prefix);
    };
    const auto scanSubgraph = [&recordPrefix](const Model::Subgraph& subgraph) {
        for (const auto& operand : subgraph.operands) {
            recordPrefix(operand);
        }
        for (const auto& operation : subgraph.operations) {
            recordPrefix(operation);
        }
    };
    scanSubgraph(model.main);
    for (const auto& subgraph : model.referenced) {
        scanSubgraph(subgraph);
    }
    // Translate the set of seen prefixes into one flag per declared extension, in order.
    std::vector<bool> used;
    used.reserve(model.extensionNameToPrefix.size());
    for (const auto& extension : model.extensionNameToPrefix) {
        const bool isUsed = prefixes.find(extension.prefix) != prefixes.end();
        used.push_back(isUsed);
    }
    CHECK_EQ(used.size(), model.extensionNameToPrefix.size());
    return used;
}
} // anonymous namespace
// Prune `model` in place:
//   1. drop operands of the main subgraph that identifyUsedOperands reports as unused, and rewrite
//      the operand indexes stored in the main subgraph accordingly;
//   2. drop referenced subgraphs that are no longer reachable from the remaining main operands;
//   3. drop pools that are no longer referenced by any remaining operand;
//   4. rewrite operand data locations (pool indexes, subgraph indexes, copied constant data);
//   5. drop extension name/prefix entries that no remaining operand or operation uses.
// The steps are order-dependent: each "identify" step inspects the model as left by the previous
// pruning steps.
// Precondition: model != nullptr
void removeDeadOperands(Model* model) {
    CHECK(model != nullptr);
    auto& mainSubgraph = model->main;

    // Step 1: prune main operands and fix every operand index that refers to them.
    const auto liveOperands = identifyUsedOperands(*model);
    const auto operandIndexMap = getMapping(liveOperands);
    keepSelectedElements(&mainSubgraph.operands, liveOperands);
    for (auto& operation : mainSubgraph.operations) {
        remapIndexes(&operation.inputs, operandIndexMap);
        remapIndexes(&operation.outputs, operandIndexMap);
    }
    remapIndexes(&mainSubgraph.inputIndexes, operandIndexMap);
    remapIndexes(&mainSubgraph.outputIndexes, operandIndexMap);

    // Step 2: prune referenced subgraphs (computed after the operand pruning above).
    const auto liveSubgraphs = identifyUsedSubgraphs(*model);
    const auto subgraphIndexMap = getMapping(liveSubgraphs);
    keepSelectedElements(&model->referenced, liveSubgraphs);

    // Step 3: prune pools (computed after the subgraph pruning above).
    const auto livePools = identifyUsedPools(*model);
    const auto poolIndexMap = getMapping(livePools);
    keepSelectedElements(&model->pools, livePools);

    // Step 4: make operand locations consistent with the new pool and subgraph numbering.
    fixOperandDataLocations(model, poolIndexMap, subgraphIndexMap);

    // Step 5: prune extensionNameToPrefix entries that are no longer used.
    const auto liveExtensions = identifyUsedExtensions(*model);
    keepSelectedElements(&model->extensionNameToPrefix, liveExtensions);
}
} // namespace android::nn