blob: 81d12829d5c9831081fc3eed1e91b4fff99f19f8 [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "MetaModel"
#include "MetaModel.h"
#include <algorithm>
#include <map>
#include <set>
#include <sstream>
#include <type_traits>
#include <utility>
#include <vector>
#include "GraphDump.h"
#include "HalInterfaces.h"
#include "Utils.h"
namespace android::nn {
namespace {
// Append one new element (as created by hidl_vec::resize) to *vec.
// Returns {index of the new element, pointer to the new element}. The pointer
// may be invalidated by any later resize of *vec, so use it immediately.
template <class T>
std::pair<uint32_t, T*> extend(hardware::hidl_vec<T>* vec) {
    const size_t oldSize = vec->size();
    vec->resize(oldSize + 1);
    T* const newElement = &(*vec)[oldSize];
    return {static_cast<uint32_t>(oldSize), newElement};
}
// Append a copy of val to *vec.
// Returns {index of the new element, pointer to the new element}, with the
// same pointer-invalidation caveat as the single-argument extend().
template <class T>
std::pair<uint32_t, T*> extend(hardware::hidl_vec<T>* vec, const T& val) {
    const auto [index, slot] = extend(vec);
    *slot = val;
    return {index, slot};
}
// Lexicographic "less than" for hidl_vec, with the same semantics as
// std::vector's operator<.
// NOTE(review): the comparisons visible in this file are made on canonical
// Operand fields; confirm whether this overload is still needed.
template <typename T>
bool operator<(const hardware::hidl_vec<T>& a, const hardware::hidl_vec<T>& b) {
    const auto aEnd = a.end();
    const auto bEnd = b.end();
    return std::lexicographical_compare(a.begin(), aEnd, b.begin(), bEnd);
}
// Compile-time mapping from a HAL Model type to a short version string. It is
// used to title graph dumps. Only the four specializations below exist;
// naming ModelVersion<X>::name for any other X fails to compile because the
// primary template is never defined.
template <class T_Model>
struct ModelVersion;
template <>
struct ModelVersion<V1_0::Model> { static constexpr char name[] = "V1_0"; };
template <>
struct ModelVersion<V1_1::Model> { static constexpr char name[] = "V1_1"; };
template <>
struct ModelVersion<V1_2::Model> { static constexpr char name[] = "V1_2"; };
template <>
struct ModelVersion<V1_3::Model> { static constexpr char name[] = "V1_3"; };
// Dispatcher mechanism: convert a canonical OperationType to the HAL
// OperationType named by T_ReturnType. It first converts to V1_3 and then
// calls the matching uncheckedConvertToV1_*. Only V1_0, V1_1 and V1_2 targets
// are provided. The primary template is declared but never defined, so any
// other T_ReturnType fails at link time. In this file it is only called from
// processOperations() for operations that were not flagged as noncompliant.
template <typename T_ReturnType>
T_ReturnType uncheckedConvertTo(OperationType type);
template <>
V1_0::OperationType uncheckedConvertTo<V1_0::OperationType>(OperationType type) {
    const auto v1_3Type = convertToV1_3(type);
    return uncheckedConvertToV1_0(v1_3Type);
}
template <>
V1_1::OperationType uncheckedConvertTo<V1_1::OperationType>(OperationType type) {
    const auto v1_3Type = convertToV1_3(type);
    return uncheckedConvertToV1_1(v1_3Type);
}
template <>
V1_2::OperationType uncheckedConvertTo<V1_2::OperationType>(OperationType type) {
    const auto v1_3Type = convertToV1_3(type);
    return uncheckedConvertToV1_2(v1_3Type);
}
// Dispatcher mechanism for calling an appropriate convertToV1_* given the
// desired return type. Note that there is no V1_1::Operand type.
template <typename T_ReturnType>
T_ReturnType convertTo(Operand operand);
template <>
V1_0::Operand convertTo<V1_0::Operand>(Operand operand) {
return convertToV1_0(convertToV1_3(operand));
}
template <>
V1_2::Operand convertTo<V1_2::Operand>(Operand operand) {
return convertToV1_2(convertToV1_3(operand));
}
// Dispatcher mechanism for converting a V1_3 operand lifetime to the lifetime
// type named by T_ReturnType. There are no V1_[12]::Operand::LifeTime types,
// so V1_0::OperandLifeTime is the only target provided. Any other
// T_ReturnType fails at link time because the primary template has no
// definition.
template <typename T_ReturnType>
T_ReturnType convertTo(V1_3::OperandLifeTime lifetime);
template <>
V1_0::OperandLifeTime convertTo<V1_0::OperandLifeTime>(V1_3::OperandLifeTime lifetime) {
    const V1_0::OperandLifeTime converted = convertToV1_0(lifetime);
    return converted;
}
// Dispatcher mechanism: fill *noncompliantOperations by calling the
// compliantWithV1_* that matches the target model type T_SlicedModel. Only the
// populated set is used by callers. Any value returned by compliantWithV1_*
// is intentionally discarded, hence the explicit void casts.
template <typename T_SlicedModel>
void getNoncompliantOperations(const V1_3::Model& model,
                               std::set<uint32_t>* noncompliantOperations);
template <>
void getNoncompliantOperations<V1_0::Model>(const V1_3::Model& model,
                                            std::set<uint32_t>* noncompliantOperations) {
    static_cast<void>(compliantWithV1_0(model, noncompliantOperations));
}
template <>
void getNoncompliantOperations<V1_1::Model>(const V1_3::Model& model,
                                            std::set<uint32_t>* noncompliantOperations) {
    static_cast<void>(compliantWithV1_1(model, noncompliantOperations));
}
template <>
void getNoncompliantOperations<V1_2::Model>(const V1_3::Model& model,
                                            std::set<uint32_t>* noncompliantOperations) {
    static_cast<void>(compliantWithV1_2(model, noncompliantOperations));
}
// Decide whether a sliced model must be discarded. Returns true when the
// slice:
//   - has no operations;
//   - has no model outputs; or
//   - fails validateModel().
// With strictSlicing, "no operations" must coincide with "no operands"
// (CHECK_EQ). A validateModel() failure then aborts (CHECK) instead of being
// reported.
template <class T_SlicedModel>
bool invalid(const T_SlicedModel& model, bool strictSlicing) {
    // A model needs at least one operation, but a slice can end up with none
    // when no operation of the original model is compliant with the sliced
    // model type.
    const bool looksEmpty = (model.operations.size() == 0);
    if (strictSlicing) {
        CHECK_EQ(looksEmpty, (model.operands.size() == 0));
    }
    if (looksEmpty) {
        return true;
    }

    // A model needs at least one output, but a slice made up only of dead
    // operations (no model output depends on them) has none.
    if (model.outputIndexes.size() == 0) {
        return true;
    }

    // In principle the slice is always valid. A failure here points to a
    // slicing bug or a compliantWith bug (see http://b/131845106).
    if (validateModel(model)) {
        return false;
    }
    LOG(WARNING) << "Sliced model fails validateModel()";
    CHECK(!strictSlicing);
    return true;
}
} // anonymous namespace
// Return the slice of the canonical model for HAL model type T_SlicedModel.
// The first call builds it with makeSlice() and caches it in *slice.
//
// An INVALID slice yields an empty (value-initialized) ReturnedSlice.
// Otherwise the result pairs a copy of the sliced HAL model with a Mapper that
// translates a sliced operation index into the original operation index; the
// lookup is bounds-checked with at(). The Mapper captures `slice`, so *slice
// must outlive every use of it.
template <class T_SlicedModel>
MetaModel::ReturnedSlice<T_SlicedModel> MetaModel::getSlice(Slice<T_SlicedModel>* slice) const {
    CHECK(slice != nullptr);
    const bool firstRequest = (slice->mState == SliceState::UNINITIALIZED);
    if (firstRequest) {
        *slice = makeSlice<T_SlicedModel>();
    }
    if (slice->mState == SliceState::INVALID) {
        return {};
    }
    Mapper slicedToOrig = [slice](uint32_t slicedOperationIndex) {
        return slice->mSlicedOperationIndexToOrigIndex.at(slicedOperationIndex);
    };
    return MetaModel::ReturnedSlice<T_SlicedModel>(
            std::make_pair(slice->mHidlModel, std::move(slicedToOrig)));
}
template MetaModel::ReturnedSlice<V1_0::Model> MetaModel::getSlice(Slice<V1_0::Model>* slice) const;
template MetaModel::ReturnedSlice<V1_1::Model> MetaModel::getSlice(Slice<V1_1::Model>* slice) const;
template MetaModel::ReturnedSlice<V1_2::Model> MetaModel::getSlice(Slice<V1_2::Model>* slice) const;
// V1_3 specialization: no operations are filtered out. On first use *slice
// is set to the whole canonical model converted to V1_3, with state NORMAL.
// The returned Mapper is the identity on operation indexes. This
// specialization never produces an INVALID slice.
template <>
MetaModel::ReturnedSlice<V1_3::Model> MetaModel::getSlice(Slice<V1_3::Model>* slice) const {
    CHECK(slice != nullptr);
    if (slice->mState == SliceState::UNINITIALIZED) {
        // When adding HAL version 1.4, make sure to handle control flow and referenced
        // subgraphs here properly. A V1_3 sliced model should contain an IF/WHILE and
        // its referenced subgraphs only if there are no V1_4+ operations in those
        // subgraphs.
        *slice = {
                .mState = SliceState::NORMAL,
                .mHidlModel = convertToV1_3(mModel),
        };
    }
    return std::make_pair(slice->mHidlModel,
                          Mapper([](uint32_t operationIndex) { return operationIndex; }));
}
// Utility class for makeSlice().
//
// For each output operand of a noncompliant operation that is the input
// operand of at least one compliant operation, we will ensure that there is
// a sliced model input whose "type" is that of the output operand. This is
// a map from operand "type" (in the original model) to model input
// operand index (in the sliced model). Unfortunately, there is no
// representation of operand "type" defined in the HAL that we can use
// naively here -- we want (OperandType, dimensions, scale, zeroPoint,
// extraParams), but these fields exist in Operand along with other fields
// that need to be excluded from the map key (numberOfConsumers, lifetime,
// location). There are several choices:
// - Don't have a map -- each output identified above gets its own sliced
//   model input (no sharing of sliced model inputs).
// - Create an operand "type" representation solely for use as a map key.
// - Write a tailored comparison function that ignores the excluded fields.
// We choose to write a tailored comparison function. If Treble were to
// generate a comparison function for us (http://b/130567619) then it might
// be better to instead reset the excluded fields to canonical values --
// then we could use the Treble provided comparison function, and the
// solution would be robust (in a correctness sense, not a sharing sense) if
// more fields are added and we neglect to canonicalize them.
//
// We also use this map for model input operands of the original model that
// become input operands of the sliced model. This means that an original
// model input operand might be commoned with other original model input
// operands and/or with original model temporary operands.
template <typename T_SlicedOperand>
class MetaModel::OrigOperandToSlicedInputOperandIndex {
   public:
    // Stores references to *slicedOperands and *slicedInputIndexes; both must
    // outlive this object. getIndex() appends to them.
    OrigOperandToSlicedInputOperandIndex(hardware::hidl_vec<T_SlicedOperand>* slicedOperands,
                                         hardware::hidl_vec<uint32_t>* slicedInputIndexes)
        : mSlicedOperands(*slicedOperands), mSlicedInputIndexes(*slicedInputIndexes) {}

    // Given an operand from the original model, return the index of the
    // corresponding model input operand from the sliced model. Creates a
    // new operand in the sliced model if necessary.
    //
    // `operand` is taken by value because the created sliced operand is a
    // modified copy (lifetime SUBGRAPH_INPUT, empty location).
    uint32_t getIndex(Operand operand) {
        // Lookup
        auto it = mMap.find(operand);
        if (it != mMap.end()) {
            VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex looked for "
                              << operand << " and found " << it->second << ": " << it->first;
            return it->second;
        }

        // Create
        operand.lifetime = Operand::LifeTime::SUBGRAPH_INPUT;
        operand.location = {};
        uint32_t slicedOperandIndex =
                extend(&mSlicedOperands, convertTo<T_SlicedOperand>(operand)).first;
        // Compare ignores lifetime and location, so keying on the modified copy
        // is equivalent to keying on the caller's operand.
        mMap[operand] = slicedOperandIndex;
        extend(&mSlicedInputIndexes, slicedOperandIndex);
        VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex created "
                          << slicedOperandIndex << ": " << operand;
        return slicedOperandIndex;
    }

   private:
    // Strict ordering on the operand "type" fields only: type, dimensions,
    // scale, zeroPoint and extraParams.
    // NOTE(review): if scale is floating point, a NaN scale would break the
    // strict weak ordering std::map requires; this assumes scales are never NaN.
    class Compare {
       public:
        bool operator()(const Operand& a, const Operand& b) const {
            if (a.type != b.type) {
                return a.type < b.type;
            }
            if (a.dimensions != b.dimensions) {
                return a.dimensions < b.dimensions;
            }
            if (a.scale != b.scale) {
                return a.scale < b.scale;
            }
            if (a.zeroPoint != b.zeroPoint) {
                return a.zeroPoint < b.zeroPoint;
            }
            return compare(a.extraParams, b.extraParams);
        }

       private:
        static bool compare(const Operand::SymmPerChannelQuantParams& a,
                            const Operand::SymmPerChannelQuantParams& b) {
            if (a.scales != b.scales) {
                return a.scales < b.scales;
            }
            return a.channelDim < b.channelDim;
        }

        // Orders first by which alternative is held, then by its payload.
        static bool compare(const Operand::ExtraParams& a, const Operand::ExtraParams& b) {
            if (a.index() != b.index()) {
                return a.index() < b.index();
            }
            if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(a)) {
                return compare(std::get<Operand::SymmPerChannelQuantParams>(a),
                               std::get<Operand::SymmPerChannelQuantParams>(b));
            }
            if (std::holds_alternative<Operand::ExtensionParams>(a)) {
                // Compare the payloads directly. There is no compare() overload
                // for ExtensionParams. Passing the payloads to compare() made
                // std::variant's converting constructor wrap each one back into an
                // ExtraParams, which re-entered this function and recursed forever.
                const auto& aParams = std::get<Operand::ExtensionParams>(a);
                const auto& bParams = std::get<Operand::ExtensionParams>(b);
                return std::lexicographical_compare(aParams.begin(), aParams.end(),
                                                    bParams.begin(), bParams.end());
            }
            if (std::holds_alternative<Operand::NoParams>(a)) {
                return false;
            }
            CHECK(false) << "Unexpected";
            return false;
        }
    };

    std::map<Operand, uint32_t, Compare> mMap;
    hardware::hidl_vec<T_SlicedOperand>& mSlicedOperands;
    hardware::hidl_vec<uint32_t>& mSlicedInputIndexes;
};
// Main loop of makeSlice(). It visits every operation of the original main
// subgraph in order and fills in slice->mHidlModel's operations, operands and
// outputIndexes.
//
// - A noncompliant operation is not copied. Each of its outputs that some
//   compliant operation consumes is replaced by a (possibly shared) sliced
//   model input obtained from origOperandToSlicedInputOperandIndex.
// - A compliant operation is appended to the sliced model, and its original
//   index is recorded in slice->mSlicedOperationIndexToOrigIndex.
//   - Inputs are remapped through origOperandIndexToSlicedIndex, and each
//     sliced input operand's numberOfConsumers is incremented.
//   - Every output gets a fresh sliced operand. An output consumed only by
//     noncompliant operations becomes a sliced model output.
//   - Every sliced output whose lifetime is SUBGRAPH_OUTPUT is appended to
//     outputIndexes.
template <class T_SlicedModel>
void MetaModel::processOperations(
        Slice<T_SlicedModel>* slice, std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex,
        OrigOperandToSlicedInputOperandIndex<typename Slice<T_SlicedModel>::Operand>*
                origOperandToSlicedInputOperandIndex,
        const std::set<uint32_t>& noncompliantOperations,
        const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const {
    using SlicedOperand = typename Slice<T_SlicedModel>::Operand;
    using SlicedOperation = typename Slice<T_SlicedModel>::Operation;
    using SlicedOperationType = typename Slice<T_SlicedModel>::OperationType;
    using SlicedLifeTime = decltype(SlicedOperand::lifetime);

    const auto& origOperands = mModel.main.operands;
    const auto& origOperations = mModel.main.operations;
    auto& slicedOperands = slice->mHidlModel.operands;
    auto& slicedOperations = slice->mHidlModel.operations;

    const std::vector<uint32_t> origOperandNumberOfConsumers =
            countNumberOfConsumers(origOperands.size(), origOperations);

    // The sliced model's encoding of the "model output" lifetime. It does not
    // depend on the operation or operand being processed, so compute it once
    // instead of once per output operand.
    const SlicedLifeTime subgraphOutputLifetime =
            convertTo<SlicedLifeTime>(V1_3::OperandLifeTime::SUBGRAPH_OUTPUT);

    for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
         ++origOperationIndex) {
        const Operation& origOperation = origOperations[origOperationIndex];
        if (noncompliantOperations.count(origOperationIndex)) {
            // Drop the operation. Any of its outputs consumed by a compliant
            // operation must be supplied as a sliced model input instead.
            for (uint32_t output : origOperation.outputs) {
                if (!inputOperandIndexesOfCompliantOperations.count(output)) {
                    continue;
                }
                const uint32_t slicedIndex =
                        origOperandToSlicedInputOperandIndex->getIndex(origOperands[output]);
                (*origOperandIndexToSlicedIndex)[output] = slicedIndex;
                VLOG(COMPILATION)
                        << "origOperandIndexToSlicedIndex noncompliant output processing created "
                        << output << " -> " << slicedIndex << ": "
                        << toString(slicedOperands[slicedIndex]);
            }
        } else {
            slice->mSlicedOperationIndexToOrigIndex.push_back(origOperationIndex);
            SlicedOperation& slicedOperation = *extend(&slicedOperations).second;
            CHECK_EQ(slice->mSlicedOperationIndexToOrigIndex.size(), slicedOperations.size());

            slicedOperation.type = uncheckedConvertTo<SlicedOperationType>(origOperation.type);

            // Model is topologically sorted, so all operation inputs must be
            // present in origOperandIndexToSlicedIndex, and no operation
            // outputs may be.

            // Operation inputs
            // - Fill in slicedOperation.inputs
            // - Update number of consumers for each input operand
            slicedOperation.inputs.resize(origOperation.inputs.size());
            std::transform(
                    origOperation.inputs.begin(), origOperation.inputs.end(),
                    slicedOperation.inputs.begin(),
                    [&origOperandIndexToSlicedIndex, &slicedOperands](uint32_t origOperandIndex) {
                        uint32_t slicedOperandIndex =
                                origOperandIndexToSlicedIndex->at(origOperandIndex);
                        slicedOperands[slicedOperandIndex].numberOfConsumers++;
                        VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant input "
                                             "processing created "
                                          << origOperandIndex << " -> " << slicedOperandIndex
                                          << ": " << toString(slicedOperands[slicedOperandIndex]);
                        return slicedOperandIndex;
                    });

            // Operation outputs
            // - Add new operands to slicedOperands
            // - Update origOperandIndexToSlicedIndex
            // - Fill in slicedOperation.outputs
            // - Record as a model output, if necessary
            const uint32_t firstOutputSlicedOperandIndex = slicedOperands.size();
            slicedOperands.resize(firstOutputSlicedOperandIndex + origOperation.outputs.size());
            slicedOperation.outputs.resize(origOperation.outputs.size());
            for (uint32_t outputNum = 0; outputNum < slicedOperation.outputs.size(); ++outputNum) {
                uint32_t origOperandIndex = origOperation.outputs[outputNum];
                uint32_t slicedOperandIndex = firstOutputSlicedOperandIndex + outputNum;
                auto& slicedOperand = slicedOperands[slicedOperandIndex];
                const auto& origOperand = origOperands[origOperandIndex];
                slicedOperand = convertTo<SlicedOperand>(origOperand);
                slicedOperand.numberOfConsumers = 0;

                CHECK_EQ(origOperandIndexToSlicedIndex->count(origOperandIndex), size_t(0));
                (*origOperandIndexToSlicedIndex)[origOperandIndex] = slicedOperandIndex;
                slicedOperation.outputs[outputNum] = slicedOperandIndex;

                if (!inputOperandIndexesOfCompliantOperations.count(origOperandIndex) &&
                    origOperandNumberOfConsumers[origOperandIndex] != 0) {
                    // Was consumed only by noncompliant operations; convert to
                    // an output of the sliced model.
                    slicedOperand.lifetime = subgraphOutputLifetime;
                }

                VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant output created "
                                  << origOperandIndex << " -> " << slicedOperandIndex << ": "
                                  << toString(slicedOperand);

                if (slicedOperand.lifetime == subgraphOutputLifetime) {
                    extend(&slice->mHidlModel.outputIndexes, slicedOperandIndex);
                }
            }
        }
    }
}
// Build the slice of the canonical model's main subgraph for HAL model type
// T_SlicedModel. Operations that getNoncompliantOperations() reports for
// T_SlicedModel are omitted.
//
// Sliced operands are created in this order:
//   1. CONSTANT_COPY / CONSTANT_REFERENCE / POINTER / NO_VALUE operands
//      consumed by a compliant operation, copied with numberOfConsumers = 0;
//   2. sliced model inputs for original model inputs consumed by a compliant
//      operation, shared by operand "type";
//   3. whatever processOperations() adds while walking the operations.
// operandValues and pools are copied wholesale. The slice's state is INVALID
// or NORMAL according to invalid().
template <class T_SlicedModel>
MetaModel::Slice<T_SlicedModel> MetaModel::makeSlice() const {
    using SlicedOperand = typename Slice<T_SlicedModel>::Operand;

    Slice<T_SlicedModel> slice;
    const auto& origOperands = mModel.main.operands;
    const auto& origOperations = mModel.main.operations;
    auto& slicedOperands = slice.mHidlModel.operands;

    // Indexes of the operations in origOperations that T_SlicedModel cannot express.
    std::set<uint32_t> noncompliantOperations;
    getNoncompliantOperations<T_SlicedModel>(convertToV1_3(mModel), &noncompliantOperations);

    // Operand index in origOperands -> operand index in slicedOperands.
    std::map<uint32_t, uint32_t> origOperandIndexToSlicedIndex;

    // Operands with these lifetimes are copied straight into the sliced model
    // as soon as a compliant operation is seen consuming them. Every other
    // consumed operand is handled later: original model inputs in the loop
    // below this one, and operation outputs in processOperations().
    const auto isCopiedAsIs = [](Operand::LifeTime lifetime) {
        switch (lifetime) {
            case Operand::LifeTime::CONSTANT_COPY:
            case Operand::LifeTime::CONSTANT_REFERENCE:
            case Operand::LifeTime::POINTER:
            case Operand::LifeTime::NO_VALUE:
                return true;
            default:
                return false;
        }
    };

    // Every operand index consumed by at least one compliant operation.
    std::set<uint32_t> inputOperandIndexesOfCompliantOperations;
    for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
         ++origOperationIndex) {
        if (noncompliantOperations.count(origOperationIndex) != 0) {
            continue;
        }
        for (const uint32_t input : origOperations[origOperationIndex].inputs) {
            const bool firstTimeSeen = inputOperandIndexesOfCompliantOperations.insert(input).second;
            if (!firstTimeSeen) {
                continue;
            }
            const Operand& origOperand = origOperands[input];
            if (!isCopiedAsIs(origOperand.lifetime)) {
                continue;
            }
            const uint32_t slicedOperandIndex =
                    extend(&slicedOperands, convertTo<SlicedOperand>(origOperand)).first;
            slicedOperands[slicedOperandIndex].numberOfConsumers = 0;
            origOperandIndexToSlicedIndex[input] = slicedOperandIndex;
            VLOG(COMPILATION) << "origOperandIndexToSlicedIndex initialization created " << input
                              << " -> " << slicedOperandIndex << ": "
                              << toString(slicedOperands[slicedOperandIndex]);
        }
    }

    OrigOperandToSlicedInputOperandIndex origOperandToSlicedInputOperandIndex(
            &slicedOperands, &slice.mHidlModel.inputIndexes);

    // An original model input becomes a sliced model input if and only if some
    // compliant operation consumes it. Inputs of the same "type" share one
    // sliced model input, and processOperations() may add more sliced inputs.
    for (const uint32_t origInputIndex : mModel.main.inputIndexes) {
        if (inputOperandIndexesOfCompliantOperations.count(origInputIndex) == 0) {
            continue;
        }
        const uint32_t slicedIndex =
                origOperandToSlicedInputOperandIndex.getIndex(origOperands[origInputIndex]);
        origOperandIndexToSlicedIndex[origInputIndex] = slicedIndex;
        VLOG(COMPILATION) << "origOperandIndexToSlicedIndex inputIndexes processing created "
                          << origInputIndex << " -> " << slicedIndex << ": "
                          << toString(slicedOperands[slicedIndex]);
    }

    // Main loop: process each operation of the original model.
    processOperations(&slice, &origOperandIndexToSlicedIndex, &origOperandToSlicedInputOperandIndex,
                      noncompliantOperations, inputOperandIndexesOfCompliantOperations);

    // Copied over as-is for simplicity. Regenerating them from the operands
    // actually present in the slice would be more complex and slower to
    // compute, but would shrink the sliced model and so reduce the cost of
    // copying it and passing it across the HAL interface.
    slice.mHidlModel.operandValues = convertToV1_0(mModel.operandValues);
    slice.mHidlModel.pools = convertToV1_0(mModel.pools);

    if (VLOG_IS_ON(COMPILATION)) {
        graphDump("Slice: From canonical", mModel);
        std::ostringstream toName;
        toName << "Slice: To " << ModelVersion<decltype(slice.mHidlModel)>::name;
        graphDump(toName.str().c_str(), uncheckedConvert(convertToV1_3(slice.mHidlModel)));
    }

    slice.mState =
            invalid(slice.mHidlModel, mStrictSlicing) ? SliceState::INVALID : SliceState::NORMAL;
    return slice;
}
template MetaModel::Slice<V1_0::Model> MetaModel::makeSlice() const;
template MetaModel::Slice<V1_1::Model> MetaModel::makeSlice() const;
} // namespace android::nn