blob: dd73bd90a29adf8710cd7c715a14359d4787e922 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"
#include <boost/test/unit_test.hpp>
#include <PrototxtConversions.hpp>

#include <sstream>
#include <string>
#include <vector>
BOOST_AUTO_TEST_SUITE(TensorflowParser)
namespace
{

/// Builds the quoted value used for a prototxt `tensor_content:` field from a list of integers.
///
/// Each element is passed, in order, through armnnUtils::ConvertInt32ToOctalString and the
/// resulting pieces are concatenated between a pair of double quotes.
/// NOTE(review): the per-element encoding is whatever ConvertInt32ToOctalString produces
/// (presumably octal escapes of the value's bytes) — confirm in PrototxtConversions.
///
/// @param data Values to encode. Each one is converted to int with static_cast before encoding,
///             so values above INT_MAX would not round-trip. The callers in this file only pass
///             small non-negative permutation indices.
/// @return The encoded values, including the surrounding double quotes.
std::string ConvertInt32VectorToOctalString(const std::vector<unsigned int>& data)
{
    std::ostringstream ss;
    ss << "\"";
    for (const unsigned int value : data)
    {
        // The conversion helper takes a signed int32.
        ss << armnnUtils::ConvertInt32ToOctalString(static_cast<int>(value));
    }
    ss << "\"";
    return ss.str();
}

} // namespace
/// Parser fixture for a three-node TensorFlow graph, written as prototxt:
///   1. "input"          - a DT_FLOAT Placeholder whose shape is `inputShape`;
///   2. "transpose/perm" - a DT_INT32 Const holding `permuteVectorData` as tensor_content;
///   3. "output"         - a Transpose consuming "input" and "transpose/perm".
/// After building the text, the constructor calls Setup() with "input" (and its shape) as the
/// network input and "output" as the network output.
struct TransposeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
    /// @param inputShape        Shape declared on the "input" placeholder.
    /// @param permuteVectorData Permutation written into the "transpose/perm" Const node.
    TransposeFixture(const armnn::TensorShape& inputShape,
                     const std::vector<unsigned int>& permuteVectorData)
    {
        using armnnUtils::ConvertTensorShapeToString;

        // The permutation is a rank-1 tensor with one entry per element of permuteVectorData.
        const unsigned int permLength = static_cast<unsigned int>(permuteVectorData.size());
        const armnn::TensorShape permShape({permLength});

        // Node 1: float placeholder; its shape block is generated from inputShape.
        m_Prototext = "node {\n"
                      " name: \"input\"\n"
                      " op: \"Placeholder\"\n"
                      " attr {\n"
                      " key: \"dtype\"\n"
                      " value {\n"
                      " type: DT_FLOAT\n"
                      " }\n"
                      " }\n"
                      " attr {\n"
                      " key: \"shape\"\n"
                      " value {\n"
                      " shape {\n";
        m_Prototext += ConvertTensorShapeToString(inputShape);
        m_Prototext += " }\n"
                       " }\n"
                       " }\n"
                       "}\n";

        // Node 2: int32 Const carrying the permutation as encoded tensor_content.
        m_Prototext += "node {\n"
                       " name: \"transpose/perm\"\n"
                       " op: \"Const\"\n"
                       " attr {\n"
                       " key: \"dtype\"\n"
                       " value {\n"
                       " type: DT_INT32\n"
                       " }\n"
                       " }\n"
                       " attr {\n"
                       " key: \"value\"\n"
                       " value {\n"
                       " tensor {\n"
                       " dtype: DT_INT32\n"
                       " tensor_shape {\n";
        m_Prototext += ConvertTensorShapeToString(permShape);
        m_Prototext += " }\n"
                       " tensor_content: ";
        m_Prototext += ConvertInt32VectorToOctalString(permuteVectorData);
        m_Prototext += "\n"
                       " }\n"
                       " }\n"
                       " }\n"
                       "}\n";

        // Node 3: the Transpose under test, wired to both nodes above.
        m_Prototext += "node {\n"
                       " name: \"output\"\n"
                       " op: \"Transpose\"\n"
                       " input: \"input\"\n"
                       " input: \"transpose/perm\"\n"
                       " attr {\n"
                       " key: \"T\"\n"
                       " value {\n"
                       " type: DT_FLOAT\n"
                       " }\n"
                       " }\n"
                       " attr {\n"
                       " key: \"Tperm\"\n"
                       " value {\n"
                       " type: DT_INT32\n"
                       " }\n"
                       " }\n"
                       "}\n";

        Setup({{"input", inputShape}}, {"output"});
    }
};
struct TransposeFixtureWithPermuteData : TransposeFixture
{
TransposeFixtureWithPermuteData()
: TransposeFixture({2, 2, 3, 4},
std::vector<unsigned int>({1, 3, 2, 0})) {}
};
BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
{
    // Row-major [2, 2, 3, 4] input holding the values 0..47.
    const std::vector<float> inputValues = {
         0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 };

    // With perm {1, 3, 2, 0}: output[i][j][k][l] == input[l][i][k][j].
    const std::vector<float> expectedValues = {
         0, 24,  4, 28,  8, 32,  1, 25,  5, 29,  9, 33,
         2, 26,  6, 30, 10, 34,  3, 27,  7, 31, 11, 35,
        12, 36, 16, 40, 20, 44, 13, 37, 17, 41, 21, 45,
        14, 38, 18, 42, 22, 46, 15, 39, 19, 43, 23, 47 };

    RunTest<4>({{"input", inputValues}}, {{"output", expectedValues}});

    // Output dimension d takes the size of input dimension perm[d].
    const armnn::TensorShape expectedShape({2, 4, 3, 2});
    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape() == expectedShape));
}
struct TransposeFixtureWithoutPermuteData : TransposeFixture
{
// In case permute data is not given, it assumes (n-1,...,0) is given
// where n is the rank of input tensor.
TransposeFixtureWithoutPermuteData()
: TransposeFixture({2, 2, 3, 4},
std::vector<unsigned int>({3, 2, 1, 0})) {}
};
BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteData, TransposeFixtureWithoutPermuteData)
{
    // Row-major [2, 2, 3, 4] input holding the values 0..47.
    const std::vector<float> inputValues = {
         0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 };

    // With the reversed perm {3, 2, 1, 0}: output[i][j][k][l] == input[l][k][j][i].
    const std::vector<float> expectedValues = {
         0, 24, 12, 36,  4, 28, 16, 40,  8, 32, 20, 44,
         1, 25, 13, 37,  5, 29, 17, 41,  9, 33, 21, 45,
         2, 26, 14, 38,  6, 30, 18, 42, 10, 34, 22, 46,
         3, 27, 15, 39,  7, 31, 19, 43, 11, 35, 23, 47 };

    RunTest<4>({{"input", inputValues}}, {{"output", expectedValues}});

    // Reversing the axes reverses the shape.
    const armnn::TensorShape expectedShape({4, 3, 2, 2});
    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape() == expectedShape));
}
BOOST_AUTO_TEST_SUITE_END()