blob: 3a97630ccfe097f778bd0090ef2787ece7615946 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include <map>
#include <string>
#include <vector>
#include "QuantizationInput.hpp"
#include "armnn/LayerVisitorBase.hpp"
#include "armnn/Tensor.hpp"
namespace armnnQuantizer
{
/// QuantizationDataSet is a structure which is created after parsing a quantization CSV file.
/// It contains records of filenames which contain refinement data per pass ID for binding ID.
/// The records are stored as QuantizationInput objects and exposed as an iterable range
/// (begin()/end(), cbegin()/cend()).
class QuantizationDataSet
{
    using QuantizationInputs = std::vector<armnnQuantizer::QuantizationInput>;

public:
    using iterator = QuantizationInputs::iterator;
    using const_iterator = QuantizationInputs::const_iterator;

    QuantizationDataSet();

    /// Creates a data set associated with the CSV file at csvFilePath (defined out of line).
    /// NOTE(review): this single-argument constructor is not `explicit`, so a std::string converts
    /// implicitly to a QuantizationDataSet. Consider `explicit` once callers are audited for
    /// copy-initialisation.
    QuantizationDataSet(std::string csvFilePath);

    ~QuantizationDataSet();

    // The user-declared destructor above suppresses the implicitly generated move operations,
    // which would make every "move" fall back to copying the record vector and path string.
    // Defaulting all four restores member-wise copy and move semantics.
    QuantizationDataSet(const QuantizationDataSet&) = default;
    QuantizationDataSet(QuantizationDataSet&&) = default;
    QuantizationDataSet& operator=(const QuantizationDataSet&) = default;
    QuantizationDataSet& operator=(QuantizationDataSet&&) = default;

    /// Returns true when the data set holds no QuantizationInput records.
    bool IsEmpty() const noexcept { return m_QuantizationInputs.empty(); }

    /// Mutable iteration over the stored QuantizationInput records.
    iterator begin() noexcept { return m_QuantizationInputs.begin(); }
    iterator end() noexcept { return m_QuantizationInputs.end(); }

    /// Read-only iteration over the stored QuantizationInput records.
    const_iterator begin() const noexcept { return m_QuantizationInputs.begin(); }
    const_iterator end() const noexcept { return m_QuantizationInputs.end(); }
    const_iterator cbegin() const noexcept { return m_QuantizationInputs.cbegin(); }
    const_iterator cend() const noexcept { return m_QuantizationInputs.cend(); }

private:
    /// Defined out of line; presumably reads m_CsvFilePath and fills m_QuantizationInputs
    /// (TODO confirm against the implementation file).
    void ParseCsvFile();

    QuantizationInputs m_QuantizationInputs;
    std::string m_CsvFilePath;
};
/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
/// Only input layers are handled here; every other visit is inherited from
/// armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> (presumably a no-op that does not throw — confirm).
class InputLayerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
{
public:
    /// Called for each input layer visited. The definition is out of line and presumably records
    /// the layer's TensorInfo in m_TensorInfos under binding id `id` (TODO confirm).
    /// `override` makes the compiler verify that this still matches the base visitor's virtual signature.
    void VisitInputLayer(const armnn::IConnectableLayer *layer, armnn::LayerBindingId id, const char* name) override;

    /// Returns the TensorInfo gathered for the given binding id.
    /// Behaviour for an id that was never visited is defined by the out-of-line implementation.
    armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId layerBindingId);

private:
    /// Gathered tensor infos, keyed by input layer binding id.
    std::map<armnn::LayerBindingId, armnn::TensorInfo> m_TensorInfos;
};
} // namespace armnnQuantizer