// blob: b038698c0196677c277c215c2f6fed54938fd273 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <string>
// Opens the Boost.Test suite "TensorflowParser"; the fixtures and test cases that follow, up to the
// matching BOOST_AUTO_TEST_SUITE_END(), belong to it.
BOOST_AUTO_TEST_SUITE(TensorflowParser)
// Fixture that assembles a text-format TensorFlow graph and hands it to the parser fixture's Setup().
//
// Graph layout (node names as they appear in the prototext):
//   graphInput0 -> Relu   \
//   graphInput1 -> Relu_1 -+-> concat   \
//   graphInput2 -> Relu_2 \              +-> concat_2   (the single requested output)
//   graphInput3 -> Relu_3 -+-> concat_1 /
//
// Every ConcatV2 node reads its axis from its own scalar DT_INT32 Const node ("<concat name>/axis"),
// and all three Const nodes carry the same value: concatDim.
struct ConcatOfConcatsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
    // inputShape0..3: shapes registered for graphInput0..3 when calling Setup().
    // concatDim:      axis value written into the "int_val" field of all three axis Const nodes.
    explicit ConcatOfConcatsFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
                                    const armnn::TensorShape& inputShape2, const armnn::TensorShape& inputShape3,
                                    unsigned int concatDim)
    {
        // The text starts with a newline, followed by the nodes in definition order.
        m_Prototext = "\n";

        m_Prototext += PlaceholderNode("graphInput0");
        m_Prototext += PlaceholderNode("graphInput1");
        m_Prototext += PlaceholderNode("graphInput2");
        m_Prototext += PlaceholderNode("graphInput3");

        m_Prototext += ReluNode("Relu", "graphInput0");
        m_Prototext += ReluNode("Relu_1", "graphInput1");
        m_Prototext += ReluNode("Relu_2", "graphInput2");
        m_Prototext += ReluNode("Relu_3", "graphInput3");

        m_Prototext += AxisConstNode("concat/axis", concatDim);
        m_Prototext += ConcatNode("concat", "Relu", "Relu_1", "concat/axis");

        m_Prototext += AxisConstNode("concat_1/axis", concatDim);
        m_Prototext += ConcatNode("concat_1", "Relu_2", "Relu_3", "concat_1/axis");

        m_Prototext += AxisConstNode("concat_2/axis", concatDim);
        m_Prototext += ConcatNode("concat_2", "concat", "concat_1", "concat_2/axis");

        Setup({{ "graphInput0", inputShape0 },
               { "graphInput1", inputShape1 },
               { "graphInput2", inputShape2 },
               { "graphInput3", inputShape3 }}, { "concat_2" });
    }

private:
    // Text of one `attr { key: "<key>" value { type: <type> } }` entry.
    static std::string TypeAttr(const std::string& key, const std::string& type)
    {
        return "attr {\n"
               "key: \"" + key + "\"\n"
               "value {\n"
               "type: " + type + "\n"
               "}\n"
               "}\n";
    }

    // Text of a DT_FLOAT Placeholder node whose "shape" attribute is left empty.
    static std::string PlaceholderNode(const std::string& name)
    {
        return "node {\n"
               "name: \"" + name + "\"\n"
               "op: \"Placeholder\"\n"
               + TypeAttr("dtype", "DT_FLOAT") +
               "attr {\n"
               "key: \"shape\"\n"
               "value {\n"
               "shape {\n"
               "}\n"
               "}\n"
               "}\n"
               "}\n";
    }

    // Text of a DT_FLOAT Relu node that reads from the node called `input`.
    static std::string ReluNode(const std::string& name, const std::string& input)
    {
        return "node {\n"
               "name: \"" + name + "\"\n"
               "op: \"Relu\"\n"
               "input: \"" + input + "\"\n"
               + TypeAttr("T", "DT_FLOAT") +
               "}\n";
    }

    // Text of a DT_INT32 Const node with an empty tensor_shape whose single int_val is `axis`.
    static std::string AxisConstNode(const std::string& name, unsigned int axis)
    {
        return "node {\n"
               "name: \"" + name + "\"\n"
               "op: \"Const\"\n"
               + TypeAttr("dtype", "DT_INT32") +
               "attr {\n"
               "key: \"value\"\n"
               "value {\n"
               "tensor {\n"
               "dtype: DT_INT32\n"
               "tensor_shape {\n"
               "}\n"
               "int_val: " + std::to_string(axis) + "\n"
               "}\n"
               "}\n"
               "}\n"
               "}\n";
    }

    // Text of a two-input (N = 2) DT_FLOAT ConcatV2 node; `axisNode` names the Const node holding the axis.
    static std::string ConcatNode(const std::string& name, const std::string& lhs, const std::string& rhs,
                                  const std::string& axisNode)
    {
        return "node {\n"
               "name: \"" + name + "\"\n"
               "op: \"ConcatV2\"\n"
               "input: \"" + lhs + "\"\n"
               "input: \"" + rhs + "\"\n"
               "input: \"" + axisNode + "\"\n"
               "attr {\n"
               "key: \"N\"\n"
               "value {\n"
               "i: 2\n"
               "}\n"
               "}\n"
               + TypeAttr("T", "DT_FLOAT")
               + TypeAttr("Tidx", "DT_INT32") +
               "}\n";
    }
};
// Concat-of-concats graph with four { 1, 1, 2, 2 } inputs, concatenated along axis 1.
struct ConcatOfConcatsFixtureNCHW : ConcatOfConcatsFixture
{
    ConcatOfConcatsFixtureNCHW()
        : ConcatOfConcatsFixture({ 1, 1, 2, 2 },  // inputShape0
                                 { 1, 1, 2, 2 },  // inputShape1
                                 { 1, 1, 2, 2 },  // inputShape2
                                 { 1, 1, 2, 2 },  // inputShape3
                                 1)               // concatDim
    {
    }
};
// Concat-of-concats graph with four { 1, 1, 2, 2 } inputs, concatenated along axis 3 (the last axis).
struct ConcatOfConcatsFixtureNHWC : ConcatOfConcatsFixture
{
    ConcatOfConcatsFixtureNHWC()
        : ConcatOfConcatsFixture({ 1, 1, 2, 2 },  // inputShape0
                                 { 1, 1, 2, 2 },  // inputShape1
                                 { 1, 1, 2, 2 },  // inputShape2
                                 { 1, 1, 2, 2 },  // inputShape3
                                 3)               // concatDim
    {
    }
};
// Axis-1 case: every input holds four values in a { 1, 1, 2, 2 } tensor, so joining along axis 1 appends
// each input as a whole block. The expected "concat_2" data is therefore the four inputs back to back.
BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW)
{
    RunTest<4>(
        // Input data, keyed by Placeholder node name.
        {
            { "graphInput0", {  0.0,  1.0,  2.0,  3.0 } },
            { "graphInput1", {  4.0,  5.0,  6.0,  7.0 } },
            { "graphInput2", {  8.0,  9.0, 10.0, 11.0 } },
            { "graphInput3", { 12.0, 13.0, 14.0, 15.0 } }
        },
        // Expected data, keyed by output node name.
        {
            { "concat_2", {  0.0,  1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,
                             8.0,  9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0 } }
        });
}
// Axis-3 case: every input is a { 1, 1, 2, 2 } tensor, i.e. two rows of two values. Joining along the last
// axis places the first rows of all four inputs side by side, then the second rows, which gives the
// interleaved expected "concat_2" data below.
BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC, ConcatOfConcatsFixtureNHWC)
{
    RunTest<4>(
        // Input data, keyed by Placeholder node name.
        {
            { "graphInput0", {  0.0,  1.0,  2.0,  3.0 } },
            { "graphInput1", {  4.0,  5.0,  6.0,  7.0 } },
            { "graphInput2", {  8.0,  9.0, 10.0, 11.0 } },
            { "graphInput3", { 12.0, 13.0, 14.0, 15.0 } }
        },
        // Expected data, keyed by output node name.
        {
            { "concat_2", {  0.0,  1.0,  4.0,  5.0,  8.0,  9.0, 12.0, 13.0,
                             2.0,  3.0,  6.0,  7.0, 10.0, 11.0, 14.0, 15.0 } }
        });
}
// Closes the TensorflowParser test suite opened near the top of this file.
BOOST_AUTO_TEST_SUITE_END()