blob: f373e5669d8da9d115157afddc2094d6b337d153 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"
BOOST_AUTO_TEST_SUITE(TensorflowParser)
// Graph which tests that nodes are re-ordered in the queue when they are encountered a second time.
// In this case R0 will be encountered first via R1 and then via R2. At that time
// we need to make sure that R0 (and the I on which it is dependent) is moved to the front again
// so that it is before both R1 and R2.
// I
// |
// R0
// / \'
// R1 R2
// \ |
// \ R3
// \|
// O
struct RediscoveredDependenciesFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
RediscoveredDependenciesFixture()
{
// Input = tf.placeholder(tf.float32, 1, "input")
// Relu0 = tf.nn.relu(input, "relu0")
// Relu1 = tf.nn.relu(relu0, "relu1")
// Relu2 = tf.nn.relu(relu0, "relu2")
// Relu3 = tf.nn.relu(relu2, "relu3")
// Output = tf.add(relu1, relu3, "output")
m_Prototext = R"(
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
}
node {
name: "relu0"
op: "Relu"
input: "input"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "relu1"
op: "Relu"
input: "relu0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "relu2"
op: "Relu"
input: "relu0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "relu3"
op: "Relu"
input: "relu2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "output"
op: "Add"
input: "relu1"
input: "relu3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
)";
SetupSingleInputSingleOutput({ 1 }, "input", "output");
}
};
BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
{
RunTest<1>({1}, {2});
}
// Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser
// getting stuck in an infinite loop.
BOOST_AUTO_TEST_CASE(SimpleCycle)
{
const char* prototext = R"(
node {
name: "r1"
op: "Relu"
input: "r2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "r2"
op: "Relu"
input: "r1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
)";
armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException);
}
// Similar to the above SimpleCycle test, but has a single node which connects to itself.
BOOST_AUTO_TEST_CASE(SingleNodeCycle)
{
const char* prototext = R"(
node {
name: "r1"
op: "Relu"
input: "r1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
)";
armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
}
// Similar to the above SimpleCycle test, but with a more complicated graph.
// I
// |
// A2---<---<-
// / \' |
// R1 R2 |
// \ | |
// \ R3 |
// \| |
// A1-->--->|
//
BOOST_AUTO_TEST_CASE(ComplexCycle)
{
// Input = tf.placeholder(tf.float32, 1, "input")
// Add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined
// Relu1 = tf.nn.relu(relu0, "relu1")
// Relu2 = tf.nn.relu(relu0, "relu2")
// Relu3 = tf.nn.relu(relu2, "relu3")
// Add1 = tf.add(relu1, relu3, "add1")
const char* prototext = R"(
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
}
node {
name: "add2"
op: "Add"
input: "input"
input: "add1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "relu1"
op: "Relu"
input: "add2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "relu2"
op: "Relu"
input: "add2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "relu3"
op: "Relu"
input: "relu2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "add1"
op: "Add"
input: "relu1"
input: "relu3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
)";
armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException);
}
// Tests that a graph with an input that is not present throws a ParseException.
BOOST_AUTO_TEST_CASE(InvalidInput)
{
const char* prototext = R"(
node {
name: "r1"
op: "Relu"
input: "a-node-that-does-not-exist"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
)";
armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
}
BOOST_AUTO_TEST_SUITE_END()