blob: c0037ae28f1ae3e78091cf351e9a13339b63a48e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "TestLayerVisitor.hpp"
namespace armnn
{
// Concrete TestLayerVisitor subclasses for layers whose Visit*Layer method takes only the layer pointer and an optional name; each subclass overrides the matching Visit*Layer method.
// Test visitor for Addition layers. VisitAdditionLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestAdditionLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestAdditionLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitAdditionLayer(const IConnectableLayer* layer,
                            const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Multiplication layers. VisitMultiplicationLayer runs the
// base-class CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestMultiplicationLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestMultiplicationLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitMultiplicationLayer(const IConnectableLayer* layer,
                                  const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Floor layers. VisitFloorLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestFloorLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestFloorLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitFloorLayer(const IConnectableLayer* layer,
                         const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Division layers. VisitDivisionLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestDivisionLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestDivisionLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitDivisionLayer(const IConnectableLayer* layer,
                            const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Subtraction layers. VisitSubtractionLayer runs the
// base-class CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestSubtractionLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestSubtractionLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitSubtractionLayer(const IConnectableLayer* layer,
                               const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Maximum layers. VisitMaximumLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestMaximumLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestMaximumLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitMaximumLayer(const IConnectableLayer* layer,
                           const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Minimum layers. VisitMinimumLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestMinimumLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestMinimumLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitMinimumLayer(const IConnectableLayer* layer,
                           const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Greater layers. VisitGreaterLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestGreaterLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestGreaterLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitGreaterLayer(const IConnectableLayer* layer,
                           const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Equal layers. VisitEqualLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestEqualLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestEqualLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitEqualLayer(const IConnectableLayer* layer,
                         const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Rsqrt layers. VisitRsqrtLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestRsqrtLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestRsqrtLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitRsqrtLayer(const IConnectableLayer* layer,
                         const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
// Test visitor for Gather layers. VisitGatherLayer runs the base-class
// CheckLayerPointer and CheckLayerName checks on whatever it is handed.
class TestGatherLayerVisitor : public TestLayerVisitor
{
public:
    // Forwards name to the TestLayerVisitor base constructor (presumably the
    // name later compared by CheckLayerName — confirm in TestLayerVisitor.hpp).
    explicit TestGatherLayerVisitor(const char* name = nullptr)
        : TestLayerVisitor(name)
    {}

    void VisitGatherLayer(const IConnectableLayer* layer,
                          const char* name = nullptr) override
    {
        // Pointer check first, then name check.
        CheckLayerPointer(layer);
        CheckLayerName(name);
    }
};
} //namespace armnn