// blob: cfac691cff17efa60425eedb8a29a3d468a6aff7 [file] [log] [blame]
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "SendCounterPacketTests.hpp"
#include <armnn/Logging.hpp>
#include <CommandHandlerFunctor.hpp>
#include <IProfilingConnection.hpp>
#include <ProfilingService.hpp>
#include <boost/polymorphic_cast.hpp>
#include <boost/test/unit_test.hpp>
#include <chrono>
#include <thread>
namespace armnn
{
namespace profiling
{
struct LogLevelSwapper
{
public:
LogLevelSwapper(armnn::LogSeverity severity)
{
// Set the new log level
armnn::ConfigureLogging(true, true, severity);
}
~LogLevelSwapper()
{
// The default log level for unit tests is "Fatal"
armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
}
};
// Scoped guard that temporarily points an output stream at a different stream buffer.
//
// Construction swaps the stream's current buffer for newStreamBuffer. The buffer the stream had
// before is put back by CancelRedirect() or, if that was never called, by the destructor.
struct StreamRedirector
{
public:
    // stream:          the stream to redirect; held by reference, so it must outlive this object.
    // newStreamBuffer: the buffer the stream uses while the redirect is active.
    StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer)
        : m_Stream(stream)
        , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
        , m_Redirected(true)
    {}

    // Not copyable: a copy would restore the backup buffer a second time when it is destroyed,
    // which could undo a redirect installed in between.
    StreamRedirector(const StreamRedirector&) = delete;
    StreamRedirector& operator=(const StreamRedirector&) = delete;

    ~StreamRedirector() { CancelRedirect(); }

    // Puts back the buffer the stream had before construction. Only the first call has any effect.
    void CancelRedirect()
    {
        // The active state is tracked with a flag instead of a null backup pointer: a stream can
        // start out with no buffer at all, and that null buffer still has to be restored.
        if (m_Redirected)
        {
            m_Stream.rdbuf(m_BackupBuffer);
            m_Redirected = false;
        }
    }

private:
    // m_Stream must stay declared before m_BackupBuffer: the constructor's initialiser for the
    // backup buffer reads m_Stream.
    std::ostream& m_Stream;
    std::streambuf* m_BackupBuffer;
    bool m_Redirected;
};
// Minimal IProfilingConnection stand-in for tests: it always reports itself as open, does nothing
// on Close(), returns false from every WritePacket(), and hands out a single connection
// acknowledged packet before timing out on every later read.
class TestProfilingConnectionBase : public IProfilingConnection
{
public:
    TestProfilingConnectionBase() = default;
    ~TestProfilingConnectionBase() = default;

    bool IsOpen() const override
    {
        return true;
    }

    void Close() override
    {
    }

    // The data is ignored and false is always returned.
    bool WritePacket(const unsigned char* buffer, uint32_t length) override
    {
        return false;
    }

    Packet ReadPacket(uint32_t timeout) override
    {
        if (!m_FirstCall)
        {
            // Every read after the first sleeps for 'timeout' milliseconds and then fails
            std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
            throw armnn::TimeoutException("Simulate a timeout error\n");
        }

        // Very first read: answer with the connection acknowledged packet
        m_FirstCall = false;
        return Packet(65536);
    }

    // True until ReadPacket() has been called once
    bool m_FirstCall = true;
};
// Test connection whose reads time out while the internal counter is a multiple of three and
// return a connection acknowledged packet otherwise.
//
// NOTE(review): the counter is only advanced on the timeout path. Starting from 0, the first read
// times out and moves the counter to 1; from then on reads return the acknowledgement and the
// counter stays at 1. The previous comment here said the acknowledgement comes "after three
// timeouts", which is not what the code does - confirm which behaviour is intended.
class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
{
public:
    TestProfilingConnectionTimeoutError()
        : m_ReadRequests(0)
    {}

    Packet ReadPacket(uint32_t timeout) override
    {
        const bool simulateTimeout = (m_ReadRequests % 3 == 0);
        if (!simulateTimeout)
        {
            // Connection acknowledged packet; the counter is left untouched on this path
            return Packet(65536);
        }

        // Sleep for 'timeout' milliseconds, count the simulated timeout, then report it
        std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
        ++m_ReadRequests;
        throw armnn::TimeoutException("Simulate a timeout error\n");
    }

    // Current counter value, i.e. the number of reads that ended in a simulated timeout.
    int ReadCalledCount()
    {
        const int timedOutReads = m_ReadRequests.load();
        return timedOutReads;
    }

private:
    std::atomic<int> m_ReadRequests;
};
// Test connection whose every read fails with a plain armnn::Exception (not a timeout).
// Each read attempt is counted before the exception is thrown.
class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
{
public:
    TestProfilingConnectionArmnnError()
        : m_ReadRequests(0)
    {}

    // 'timeout' is ignored: the failure is raised immediately.
    Packet ReadPacket(uint32_t timeout) override
    {
        ++m_ReadRequests;
        throw armnn::Exception("Simulate a non-timeout error");
    }

    // Number of times ReadPacket() has been called so far.
    int ReadCalledCount()
    {
        const int readAttempts = m_ReadRequests.load();
        return readAttempts;
    }

private:
    std::atomic<int> m_ReadRequests;
};
// Command handler functor that only counts how many times it has been invoked.
class TestFunctorA : public CommandHandlerFunctor
{
public:
    // Reuse the CommandHandlerFunctor constructors unchanged
    using CommandHandlerFunctor::CommandHandlerFunctor;

    // The packet content is ignored; each invocation just bumps the counter.
    void operator()(const Packet& packet) override
    {
        ++m_Count;
    }

    // Number of invocations seen so far.
    int GetCount()
    {
        return m_Count;
    }

private:
    int m_Count = 0;
};
// Distinct functor type with exactly the behaviour of TestFunctorA.
class TestFunctorB : public TestFunctorA
{
public:
    // Inherit the base constructors; the access label has no effect on their accessibility
    using TestFunctorA::TestFunctorA;
};
// Distinct functor type with exactly the behaviour of TestFunctorA.
class TestFunctorC : public TestFunctorA
{
public:
    // Inherit the base constructors; the access label has no effect on their accessibility
    using TestFunctorA::TestFunctorA;
};
class SwapProfilingConnectionFactoryHelper : public ProfilingService
{
public:
using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
SwapProfilingConnectionFactoryHelper()
: ProfilingService()
, m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
, m_BackupProfilingConnectionFactory(nullptr)
{
BOOST_CHECK(m_MockProfilingConnectionFactory);
SwapProfilingConnectionFactory(ProfilingService::Instance(),
m_MockProfilingConnectionFactory.get(),
m_BackupProfilingConnectionFactory);
BOOST_CHECK(m_BackupProfilingConnectionFactory);
}
~SwapProfilingConnectionFactoryHelper()
{
BOOST_CHECK(m_BackupProfilingConnectionFactory);
IProfilingConnectionFactory* temp = nullptr;
SwapProfilingConnectionFactory(ProfilingService::Instance(),
m_BackupProfilingConnectionFactory,
temp);
}
MockProfilingConnection* GetMockProfilingConnection()
{
IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance());
return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection);
}
void ForceTransitionToState(ProfilingState newState)
{
TransitionToState(ProfilingService::Instance(), newState);
}
void WaitForProfilingPacketsSent(MockProfilingConnection* mockProfilingConnection, uint32_t timeout = 1000)
{
if (!mockProfilingConnection->HasWrittenData())
{
WaitForPacketSent(ProfilingService::Instance(), timeout);
// It's possible the wait has timed out. Check there is some data.
if (!mockProfilingConnection->HasWrittenData())
{
throw RuntimeException("ProfilingTests::WaitForProfilingPacketsSent timeout waiting for packet.");
}
}
}
private:
MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
};
} // namespace profiling
} // namespace armnn