blob: f67a65cd09833cdf2a0befe24b2017c902f6380b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "TestLayerVisitor.hpp"
#include <boost/test/unit_test.hpp>
namespace armnn
{
/// Checks that two layer binding ids are equal, reporting a Boost.Test failure if they differ.
///
/// @param visitorId Binding id held by the test visitor.
/// @param id        Binding id received by the visit callback.
///
/// Marked inline because the definition lives in a header: a non-inline definition would be
/// emitted by every translation unit that includes this file and break the link with
/// duplicate-symbol errors.
inline void CheckLayerBindingId(LayerBindingId visitorId, LayerBindingId id)
{
    BOOST_CHECK_EQUAL(visitorId, id);
}
// Concrete TestLayerVisitor subclasses for layers taking LayerBindingId argument with overridden VisitLayer methods
class TestInputLayerVisitor : public TestLayerVisitor
{
private:
LayerBindingId visitorId;
public:
explicit TestInputLayerVisitor(LayerBindingId id, const char* name = nullptr)
: TestLayerVisitor(name)
, visitorId(id)
{};
void VisitInputLayer(const IConnectableLayer* layer,
LayerBindingId id,
const char* name = nullptr) override
{
CheckLayerPointer(layer);
CheckLayerBindingId(visitorId, id);
CheckLayerName(name);
};
};
class TestOutputLayerVisitor : public TestLayerVisitor
{
private:
LayerBindingId visitorId;
public:
explicit TestOutputLayerVisitor(LayerBindingId id, const char* name = nullptr)
: TestLayerVisitor(name)
, visitorId(id)
{};
void VisitOutputLayer(const IConnectableLayer* layer,
LayerBindingId id,
const char* name = nullptr) override
{
CheckLayerPointer(layer);
CheckLayerBindingId(visitorId, id);
CheckLayerName(name);
};
};
} //namespace armnn