blob: 818cb17a6533d95b03461a80ae3f92c969879c46 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ModelAccuracyChecker.hpp"
#include <armnn/Logging.hpp>
#include <boost/filesystem.hpp>
#include <map>
#include <vector>
namespace armnnUtils
{
armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabels,
const std::vector<LabelCategoryNames>& modelOutputLabels)
: m_GroundTruthLabelSet(validationLabels)
, m_ModelOutputLabels(modelOutputLabels)
{}
// Returns the Top-k accuracy as a percentage in the range [0, 100].
//
// k - how many top-ranked predictions count as a hit. Values above 10 are clamped to 10
//     and a warning is logged.
//
// The result is the sum of m_TopK[1] .. m_TopK[k] (index 0 is never read) multiplied by 100
// and divided by m_ImagesProcessed.
// NOTE(review): m_TopK[i] is presumably the number of images whose correct label was the
// i-th ranked prediction; confirm against the code that fills it.
//
// Returns 0 when no image has been processed yet, instead of the NaN that 0/0 would give.
float ModelAccuracyChecker::GetAccuracy(unsigned int k)
{
    if (k > 10)
    {
        ARMNN_LOG(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
                              "Printing Top 10 Accuracy result!";
        k = 10;
    }

    // Guard the division below: with nothing processed there is no meaningful ratio.
    if (m_ImagesProcessed == 0)
    {
        return 0.0f;
    }

    unsigned int total = 0;
    for (unsigned int i = 1; i <= k; ++i)
    {
        total += m_TopK[i];
    }
    return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
}
// Split a string into tokens by a delimiter.
//
// originalString    - the text to split.
// delimiter         - the separator, matched as a whole substring (not as a set of characters).
// includeEmptyToken - when true, empty tokens produced by leading, trailing or adjacent
//                     delimiters are kept; when false they are dropped.
//
// Returns the tokens in their original order. An empty delimiter offers no split point, so the
// whole string is treated as a single token (subject to the same empty-token rule).
std::vector<std::string>
SplitBy(const std::string& originalString, const std::string& delimiter, bool includeEmptyToken)
{
    std::vector<std::string> tokens;
    size_t cur = 0;

    // An empty delimiter matches at every position without advancing the cursor, which would
    // make the search loop spin forever; skip the search entirely in that case.
    if (!delimiter.empty())
    {
        size_t next = 0;
        while ((next = originalString.find(delimiter, cur)) != std::string::npos)
        {
            // Skip empty tokens, unless explicitly stated to include them.
            if (next - cur > 0 || includeEmptyToken)
            {
                tokens.push_back(originalString.substr(cur, next - cur));
            }
            cur = next + delimiter.size();
        }
    }

    // Get the remaining token (everything after the last delimiter).
    // Skip empty tokens, unless explicitly stated to include them.
    if (originalString.size() - cur > 0 || includeEmptyToken)
    {
        tokens.push_back(originalString.substr(cur, originalString.size() - cur));
    }
    return tokens;
}
// Trim from both ends of originalString every character that appears in characterSet.
//
// originalString - the text to trim; it is not modified.
// characterSet   - the characters to remove; asserted to be non-empty.
//
// Returns the trimmed copy, or an empty string when originalString is empty or consists
// only of to-be-stripped characters.
std::string Strip(const std::string& originalString, const std::string& characterSet)
{
    BOOST_ASSERT(!characterSet.empty());

    const std::size_t keepFrom = originalString.find_first_not_of(characterSet);
    if (keepFrom == std::string::npos)
    {
        // No character survives the trim.
        return "";
    }

    // At least one character survives, so the backwards search always succeeds here.
    const std::size_t keepTo = originalString.find_last_not_of(characterSet);
    const std::size_t keepLength = keepTo + 1 - keepFrom;
    return originalString.substr(keepFrom, keepLength);
}
} // namespace armnnUtils