IVGCVSW-4161 Provide for per model call back registration

!armnn:2810

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: Idf56d42bd767baa5df0059a2f489f75281f8ac71
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d5da0d3..3d0f518 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1002,7 +1002,7 @@
         tests/profiling/gatordmock/StreamMetadataCommandHandler.hpp
         )
 
-    include_directories(src/profiling tests/profiling tests/profiling/gatordmock)
+    include_directories(src/profiling tests/profiling tests/profiling/gatordmock src/timelineDecoder)
 
     add_library_ex(gatordMockService STATIC ${gatord_mock_sources})
     target_include_directories(gatordMockService PRIVATE src/armnnUtils)
@@ -1012,6 +1012,7 @@
 
     target_link_libraries(GatordMock
         armnn
+        timelineDecoder
         gatordMockService
         ${Boost_PROGRAM_OPTIONS_LIBRARY}
         ${Boost_SYSTEM_LIBRARY})
diff --git a/src/timelineDecoder/TimelineCaptureCommandHandler.cpp b/src/timelineDecoder/TimelineCaptureCommandHandler.cpp
index fb6935e..58edd9f 100644
--- a/src/timelineDecoder/TimelineCaptureCommandHandler.cpp
+++ b/src/timelineDecoder/TimelineCaptureCommandHandler.cpp
@@ -6,7 +6,7 @@
 #include "TimelineCaptureCommandHandler.hpp"
 
 #include <string>
-
+#include <armnn/Logging.hpp>
 namespace armnn
 {
 
@@ -28,7 +28,15 @@
     uint32_t offset = 0;
     m_PacketLength = packet.GetLength();
 
-    if ( m_PacketLength < 8 )
+    // We are expecting TimelineDirectoryCaptureCommandHandler to set the thread id size
+    // if it not set in the constructor
+    if (m_ThreadIdSize == 0)
+    {
+        ARMNN_LOG(error) << "TimelineCaptureCommandHandler: m_ThreadIdSize has not been set";
+        return;
+    }
+
+    if (packet.GetLength() < 8)
     {
         return;
     }
@@ -125,6 +133,11 @@
     m_TimelineDecoder.CreateEvent(event);
 }
 
+void TimelineCaptureCommandHandler::SetThreadIdSize(uint32_t size)
+{
+    m_ThreadIdSize = size;
+}
+
 void TimelineCaptureCommandHandler::operator()(const profiling::Packet& packet)
 {
     ParseData(packet);
diff --git a/src/timelineDecoder/TimelineCaptureCommandHandler.hpp b/src/timelineDecoder/TimelineCaptureCommandHandler.hpp
index b69e615..e143b5f 100644
--- a/src/timelineDecoder/TimelineCaptureCommandHandler.hpp
+++ b/src/timelineDecoder/TimelineCaptureCommandHandler.hpp
@@ -5,9 +5,9 @@
 
 #pragma once
 
-#include <CommandHandlerFunctor.hpp>
 #include "armnn/profiling/ITimelineDecoder.hpp"
 
+#include <CommandHandlerFunctor.hpp>
 #include <Packet.hpp>
 #include <ProfilingUtils.hpp>
 
@@ -31,11 +31,11 @@
                                   uint32_t packetId,
                                   uint32_t version,
                                   ITimelineDecoder& timelineDecoder,
-                                  uint32_t threadId_size)
-        : CommandHandlerFunctor(familyId, packetId, version),
-          m_TimelineDecoder(timelineDecoder),
-          m_ThreadIdSize(threadId_size),
-          m_PacketLength(0)
+                                  uint32_t threadIdSize = 0)
+        : CommandHandlerFunctor(familyId, packetId, version)
+        , m_TimelineDecoder(timelineDecoder)
+        , m_ThreadIdSize(threadIdSize)
+        , m_PacketLength(0)
     {}
 
     void operator()(const armnn::profiling::Packet& packet) override;
@@ -46,12 +46,13 @@
     void ReadRelationship(const unsigned char* data, uint32_t& offset);
     void ReadEvent(const unsigned char* data, uint32_t& offset);
 
+    void SetThreadIdSize(uint32_t size);
+
 private:
     void ParseData(const armnn::profiling::Packet& packet);
 
     ITimelineDecoder& m_TimelineDecoder;
-
-    const uint32_t            m_ThreadIdSize;
+    uint32_t m_ThreadIdSize;
     unsigned int              m_PacketLength;
     static const ReadFunction m_ReadFunctions[];
 
diff --git a/src/timelineDecoder/TimelineDecoder.cpp b/src/timelineDecoder/TimelineDecoder.cpp
index 2f9ac13..f7f4663 100644
--- a/src/timelineDecoder/TimelineDecoder.cpp
+++ b/src/timelineDecoder/TimelineDecoder.cpp
@@ -4,13 +4,14 @@
 //
 
 #include "TimelineDecoder.hpp"
-#include "../profiling/ProfilingUtils.hpp"
-
+#include <ProfilingUtils.hpp>
 #include <iostream>
+
 namespace armnn
 {
 namespace timelinedecoder
 {
+
 TimelineDecoder::TimelineStatus TimelineDecoder::CreateEntity(const Entity &entity)
 {
     if (m_OnNewEntityCallback == nullptr)
@@ -120,6 +121,34 @@
     return TimelineStatus::TimelineStatus_Success;
 }
 
+void TimelineDecoder::SetDefaultCallbacks()
+{
+    SetEntityCallback([](Model& model, const ITimelineDecoder::Entity entity)
+    {
+        model.m_Entities.emplace_back(entity);
+    });
+
+    SetEventClassCallback([](Model& model, const ITimelineDecoder::EventClass eventClass)
+    {
+        model.m_EventClasses.emplace_back(eventClass);
+    });
+
+    SetEventCallback([](Model& model, const ITimelineDecoder::Event event)
+    {
+        model.m_Events.emplace_back(event);
+    });
+
+    SetLabelCallback([](Model& model, const ITimelineDecoder::Label label)
+    {
+        model.m_Labels.emplace_back(label);
+    });
+
+    SetRelationshipCallback([](Model& model, const ITimelineDecoder::Relationship relationship)
+    {
+        model.m_Relationships.emplace_back(relationship);
+    });
+}
+
 void TimelineDecoder::print()
 {
     printLabels();
diff --git a/src/timelineDecoder/TimelineDecoder.hpp b/src/timelineDecoder/TimelineDecoder.hpp
index 4056731..c6d1e4e 100644
--- a/src/timelineDecoder/TimelineDecoder.hpp
+++ b/src/timelineDecoder/TimelineDecoder.hpp
@@ -39,13 +39,14 @@
 
     const Model& GetModel();
 
-
     TimelineStatus SetEntityCallback(const OnNewEntityCallback);
     TimelineStatus SetEventClassCallback(const OnNewEventClassCallback);
     TimelineStatus SetEventCallback(const OnNewEventCallback);
     TimelineStatus SetLabelCallback(const OnNewLabelCallback);
     TimelineStatus SetRelationshipCallback(const OnNewRelationshipCallback);
 
+    void SetDefaultCallbacks();
+
     void print();
 
 private:
diff --git a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp
index 655e461..74aefea 100644
--- a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp
+++ b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.cpp
@@ -4,6 +4,7 @@
 //
 
 #include "TimelineDirectoryCaptureCommandHandler.hpp"
+#include "TimelineCaptureCommandHandler.hpp"
 
 #include <iostream>
 #include <string>
@@ -41,6 +42,8 @@
     {
         m_SwTraceMessages.push_back(profiling::ReadSwTraceMessage(data, offset));
     }
+
+    m_TimelineCaptureCommandHandler.SetThreadIdSize(m_SwTraceHeader.m_ThreadIdBytes);
 }
 
 void TimelineDirectoryCaptureCommandHandler::Print()
diff --git a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp
index b4e0fd2..a22a5d9 100644
--- a/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp
+++ b/src/timelineDecoder/TimelineDirectoryCaptureCommandHandler.hpp
@@ -5,7 +5,8 @@
 
 #pragma once
 
-#include <CommandHandlerFunctor.hpp>
+
+#include <TimelineCaptureCommandHandler.hpp>
 #include <Packet.hpp>
 #include <PacketBuffer.hpp>
 #include <ProfilingUtils.hpp>
@@ -26,8 +27,10 @@
     TimelineDirectoryCaptureCommandHandler(uint32_t familyId,
                                            uint32_t packetId,
                                            uint32_t version,
+                                           TimelineCaptureCommandHandler& timelineCaptureCommandHandler,
                                            bool quietOperation = false)
         : CommandHandlerFunctor(familyId, packetId, version)
+        , m_TimelineCaptureCommandHandler(timelineCaptureCommandHandler)
         , m_QuietOperation(quietOperation)
     {}
 
@@ -40,6 +43,7 @@
     void ParseData(const armnn::profiling::Packet& packet);
     void Print();
 
+    TimelineCaptureCommandHandler& m_TimelineCaptureCommandHandler;
     bool m_QuietOperation;
 };
 
diff --git a/src/timelineDecoder/tests/TimelineTests.cpp b/src/timelineDecoder/tests/TimelineTests.cpp
index 62b4330..1f55515 100644
--- a/src/timelineDecoder/tests/TimelineTests.cpp
+++ b/src/timelineDecoder/tests/TimelineTests.cpp
@@ -83,8 +83,13 @@
 
     profiling::PacketVersionResolver packetVersionResolver;
 
+    TimelineDecoder timelineDecoder;
+    TimelineCaptureCommandHandler timelineCaptureCommandHandler(
+            1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
+
     TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
-            1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), true);
+            1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
+            timelineCaptureCommandHandler, true);
 
     sendTimelinePacket->SendTimelineMessageDirectoryPackage();
     sendTimelinePacket->Commit();
@@ -151,6 +156,7 @@
     TimelineDecoder timelineDecoder;
     const TimelineDecoder::Model& model = timelineDecoder.GetModel();
 
+
     TimelineCaptureCommandHandler timelineCaptureCommandHandler(
         1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder, threadIdSize);
 
diff --git a/tests/profiling/gatordmock/GatordMockMain.cpp b/tests/profiling/gatordmock/GatordMockMain.cpp
index edad85c..e19461f 100644
--- a/tests/profiling/gatordmock/GatordMockMain.cpp
+++ b/tests/profiling/gatordmock/GatordMockMain.cpp
@@ -3,60 +3,67 @@
 // SPDX-License-Identifier: MIT
 //
 
-#include "../../../src/profiling/PacketVersionResolver.hpp"
-#include "../../../src/profiling/PeriodicCounterSelectionCommandHandler.hpp"
+#include "PacketVersionResolver.hpp"
 #include "CommandFileParser.hpp"
 #include "CommandLineProcessor.hpp"
 #include "DirectoryCaptureCommandHandler.hpp"
 #include "GatordMockService.hpp"
 #include "PeriodicCounterCaptureCommandHandler.hpp"
 #include "PeriodicCounterSelectionResponseHandler.hpp"
+#include <TimelineDecoder.hpp>
+#include <TimelineDirectoryCaptureCommandHandler.hpp>
+#include <TimelineCaptureCommandHandler.hpp>
 
 #include <iostream>
 #include <string>
+#include <NetworkSockets.hpp>
+#include <signal.h>
 
-int main(int argc, char* argv[])
+using namespace armnn;
+using namespace gatordmock;
+
+// Used to capture ctrl-c so we can close any remaining sockets before exit
+static volatile bool run = true;
+void exit_capture(int signum)
 {
-    // Process command line arguments
-    armnn::gatordmock::CommandLineProcessor cmdLine;
-    if (!cmdLine.ProcessCommandLine(argc, argv))
-    {
-        return EXIT_FAILURE;
-    }
+    IgnoreUnused(signum);
+    run = false;
+}
 
-    armnn::profiling::PacketVersionResolver packetVersionResolver;
+bool CreateMockService(armnnUtils::Sockets::Socket clientConnection, std::string commandFile, bool isEchoEnabled)
+{
+    profiling::PacketVersionResolver packetVersionResolver;
     // Create the Command Handler Registry
-    armnn::profiling::CommandHandlerRegistry registry;
+    profiling::CommandHandlerRegistry registry;
+
+    timelinedecoder::TimelineDecoder timelineDecoder;
+    timelineDecoder.SetDefaultCallbacks();
 
     // This functor will receive back the selection response packet.
-    armnn::gatordmock::PeriodicCounterSelectionResponseHandler periodicCounterSelectionResponseHandler(
-        0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue());
+    PeriodicCounterSelectionResponseHandler periodicCounterSelectionResponseHandler(
+            0, 4, packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue());
     // This functor will receive the counter data.
-    armnn::gatordmock::PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler(
-        3, 0, packetVersionResolver.ResolvePacketVersion(3, 0).GetEncodedValue());
+    PeriodicCounterCaptureCommandHandler counterCaptureCommandHandler(
+            3, 0, packetVersionResolver.ResolvePacketVersion(3, 0).GetEncodedValue());
 
-    armnn::profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
-        0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), false);
+    profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
+            0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), false);
+
+    timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler(
+            1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
+
+    timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
+            1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
+            timelineCaptureCommandHandler, false);
 
     // Register different derived functors
     registry.RegisterFunctor(&periodicCounterSelectionResponseHandler);
     registry.RegisterFunctor(&counterCaptureCommandHandler);
     registry.RegisterFunctor(&directoryCaptureCommandHandler);
+    registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler);
+    registry.RegisterFunctor(&timelineCaptureCommandHandler);
 
-    armnn::gatordmock::GatordMockService mockService(registry, cmdLine.IsEchoEnabled());
-
-    if (!mockService.OpenListeningSocket(cmdLine.GetUdsNamespace()))
-    {
-        return EXIT_FAILURE;
-    }
-    std::cout << "Bound to UDS namespace: \\0" << cmdLine.GetUdsNamespace() << std::endl;
-
-    // Wait for a single connection.
-    if (-1 == mockService.BlockForOneClient())
-    {
-        return EXIT_FAILURE;
-    }
-    std::cout << "Client connection established." << std::endl;
+    GatordMockService mockService(clientConnection, registry, isEchoEnabled);
 
     // Send receive the strweam metadata and send connection ack.
     if (!mockService.WaitForStreamMetaData())
@@ -69,11 +76,60 @@
     mockService.LaunchReceivingThread();
 
     // Process the SET and WAIT command from the file.
-    armnn::gatordmock::CommandFileParser commandLineParser;
-    commandLineParser.ParseFile(cmdLine.GetCommandFile(), mockService);
+    CommandFileParser commandLineParser;
+    commandLineParser.ParseFile(commandFile, mockService);
 
     // Once we've finished processing the file wait for the receiving thread to close.
     mockService.WaitForReceivingThread();
 
+    if(isEchoEnabled)
+    {
+        timelineDecoder.print();
+    }
+
     return EXIT_SUCCESS;
 }
+
+int main(int argc, char* argv[])
+{
+    // We need to capture ctrl-c so we can close any remaining sockets before exit
+    signal(SIGINT, exit_capture);
+
+    // Process command line arguments
+    CommandLineProcessor cmdLine;
+    if (!cmdLine.ProcessCommandLine(argc, argv))
+    {
+        return EXIT_FAILURE;
+    }
+
+    std::vector<std::thread> threads;
+    std::string commandFile = cmdLine.GetCommandFile();
+
+    armnnUtils::Sockets::Initialize();
+    armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
+
+    if (!GatordMockService::OpenListeningSocket(listeningSocket, cmdLine.GetUdsNamespace(), 10))
+    {
+        return EXIT_FAILURE;
+    }
+    std::cout << "Bound to UDS namespace: \\0" << cmdLine.GetUdsNamespace() << std::endl;
+
+    // make the socket non-blocking so we can exit the loop
+    armnnUtils::Sockets::SetNonBlocking(listeningSocket);
+    while (run)
+    {
+        armnnUtils::Sockets::Socket clientConnection =
+                armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
+
+        if (clientConnection > 0)
+        {
+            threads.emplace_back(
+                    std::thread(CreateMockService, clientConnection, commandFile, cmdLine.IsEchoEnabled()));
+        }
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(100u));
+    }
+
+    armnnUtils::Sockets::Close(listeningSocket);
+    std::for_each(threads.begin(), threads.end(), [](std::thread& t){t.join();});
+}
\ No newline at end of file
diff --git a/tests/profiling/gatordmock/GatordMockService.cpp b/tests/profiling/gatordmock/GatordMockService.cpp
index c521196..a3f732c 100644
--- a/tests/profiling/gatordmock/GatordMockService.cpp
+++ b/tests/profiling/gatordmock/GatordMockService.cpp
@@ -24,11 +24,11 @@
 namespace gatordmock
 {
 
-bool GatordMockService::OpenListeningSocket(std::string udsNamespace)
+bool GatordMockService::OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket,
+                                            const std::string udsNamespace,
+                                            const int numOfConnections)
 {
-    Sockets::Initialize();
-    m_ListeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
-    if (-1 == m_ListeningSocket)
+    if (-1 == listeningSocket)
     {
         std::cerr << ": Socket construction failed: " << strerror(errno) << std::endl;
         return false;
@@ -41,13 +41,13 @@
     udsAddress.sun_family = AF_UNIX;
 
     // Bind the socket to the UDS namespace.
-    if (-1 == bind(m_ListeningSocket, reinterpret_cast<const sockaddr*>(&udsAddress), sizeof(sockaddr_un)))
+    if (-1 == bind(listeningSocket, reinterpret_cast<const sockaddr*>(&udsAddress), sizeof(sockaddr_un)))
     {
         std::cerr << ": Binding on socket failed: " << strerror(errno) << std::endl;
         return false;
     }
-    // Listen for 1 connection.
-    if (-1 == listen(m_ListeningSocket, 1))
+    // Listen for 10 connections.
+    if (-1 == listen(listeningSocket, numOfConnections))
     {
         std::cerr << ": Listen call on socket failed: " << strerror(errno) << std::endl;
         return false;
@@ -55,17 +55,6 @@
     return true;
 }
 
-Sockets::Socket GatordMockService::BlockForOneClient()
-{
-    m_ClientConnection = Sockets::Accept(m_ListeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
-    if (-1 == m_ClientConnection)
-    {
-        std::cerr << ": Failure when waiting for a client connection: " << strerror(errno) << std::endl;
-        return -1;
-    }
-    return m_ClientConnection;
-}
-
 bool GatordMockService::WaitForStreamMetaData()
 {
     if (m_EchoPackets)
diff --git a/tests/profiling/gatordmock/GatordMockService.hpp b/tests/profiling/gatordmock/GatordMockService.hpp
index f91e902..c00685f 100644
--- a/tests/profiling/gatordmock/GatordMockService.hpp
+++ b/tests/profiling/gatordmock/GatordMockService.hpp
@@ -39,10 +39,13 @@
 public:
     /// @param registry reference to a command handler registry.
     /// @param echoPackets if true the raw packets will be printed to stdout.
-    GatordMockService(armnn::profiling::CommandHandlerRegistry& registry, bool echoPackets)
-        : m_HandlerRegistry(registry)
-        , m_EchoPackets(echoPackets)
-        , m_CloseReceivingThread(false)
+    GatordMockService(armnnUtils::Sockets::Socket clientConnection,
+                      armnn::profiling::CommandHandlerRegistry& registry,
+                      bool echoPackets)
+            : m_ClientConnection(clientConnection)
+            , m_HandlerRegistry(registry)
+            , m_EchoPackets(echoPackets)
+            , m_CloseReceivingThread(false)
     {
         m_PacketsReceivedCount.store(0, std::memory_order_relaxed);
     }
@@ -51,17 +54,14 @@
     {
         // We have set SOCK_CLOEXEC on these sockets but we'll close them to be good citizens.
         armnnUtils::Sockets::Close(m_ClientConnection);
-        armnnUtils::Sockets::Close(m_ListeningSocket);
     }
 
     /// Establish the Unix domain socket and set it to listen for connections.
     /// @param udsNamespace the namespace (socket address) associated with the listener.
     /// @return true only if the socket has been correctly setup.
-    bool OpenListeningSocket(std::string udsNamespace);
-
-    /// Block waiting to accept one client to connect to the UDS.
-    /// @return the file descriptor of the client connection.
-    armnnUtils::Sockets::Socket BlockForOneClient();
+    static bool OpenListeningSocket(armnnUtils::Sockets::Socket listeningSocket,
+                                    const std::string udsNamespace,
+                                    const int numOfConnections = 1);
 
     /// Once the connection is open wait to receive the stream meta data packet from the client. Reading this
     /// packet differs from others as we need to determine endianness.
@@ -118,6 +118,8 @@
 private:
     void ReceiveLoop(GatordMockService& mockService);
 
+    int MainLoop(armnn::profiling::CommandHandlerRegistry& registry, armnnUtils::Sockets::Socket m_ClientConnection);
+
     /// Block on the client connection until a complete packet has been received. This is a placeholder function to
     /// enable early testing of the tool.
     /// @return true if a valid packet has been received.
@@ -145,11 +147,10 @@
     uint32_t m_StreamMetaDataMaxDataLen;
     uint32_t m_StreamMetaDataPid;
 
+    armnnUtils::Sockets::Socket m_ClientConnection;
     armnn::profiling::CommandHandlerRegistry& m_HandlerRegistry;
 
     bool m_EchoPackets;
-    armnnUtils::Sockets::Socket m_ListeningSocket;
-    armnnUtils::Sockets::Socket m_ClientConnection;
     std::thread m_ListeningThread;
     std::atomic<bool> m_CloseReceivingThread;
 };
diff --git a/tests/profiling/gatordmock/tests/GatordMockTests.cpp b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
index 78c6f11..bba8485 100644
--- a/tests/profiling/gatordmock/tests/GatordMockTests.cpp
+++ b/tests/profiling/gatordmock/tests/GatordMockTests.cpp
@@ -11,6 +11,7 @@
 #include <StreamMetadataCommandHandler.hpp>
 
 #include <TimelineDirectoryCaptureCommandHandler.hpp>
+#include <TimelineDecoder.hpp>
 
 #include <test/ProfilingMocks.hpp>
 
@@ -21,7 +22,7 @@
 BOOST_AUTO_TEST_SUITE(GatordMockTests)
 
 using namespace armnn;
-using namespace std::this_thread;    // sleep_for, sleep_until
+using namespace std::this_thread;
 using namespace std::chrono_literals;
 
 BOOST_AUTO_TEST_CASE(CounterCaptureHandlingTest)
@@ -118,6 +119,9 @@
     // Create the Command Handler Registry
     profiling::CommandHandlerRegistry registry;
 
+    timelinedecoder::TimelineDecoder timelineDecoder;
+    timelineDecoder.SetDefaultCallbacks();
+
     // Update with derived functors
     gatordmock::StreamMetadataCommandHandler streamMetadataCommandHandler(
         0, 0, packetVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true);
@@ -128,18 +132,29 @@
     profiling::DirectoryCaptureCommandHandler directoryCaptureCommandHandler(
         0, 2, packetVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true);
 
+    timelinedecoder::TimelineCaptureCommandHandler timelineCaptureCommandHandler(
+            1, 1, packetVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), timelineDecoder);
+
     timelinedecoder::TimelineDirectoryCaptureCommandHandler timelineDirectoryCaptureCommandHandler(
-        1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(), true);
+        1, 0, packetVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
+        timelineCaptureCommandHandler, true);
 
     // Register different derived functors
     registry.RegisterFunctor(&streamMetadataCommandHandler);
     registry.RegisterFunctor(&counterCaptureCommandHandler);
     registry.RegisterFunctor(&directoryCaptureCommandHandler);
     registry.RegisterFunctor(&timelineDirectoryCaptureCommandHandler);
+
     // Setup the mock service to bind to the UDS.
     std::string udsNamespace = "gatord_namespace";
-    gatordmock::GatordMockService mockService(registry, false);
-    mockService.OpenListeningSocket(udsNamespace);
+
+    armnnUtils::Sockets::Initialize();
+    armnnUtils::Sockets::Socket listeningSocket = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
+
+    if (!gatordmock::GatordMockService::OpenListeningSocket(listeningSocket, udsNamespace))
+    {
+        BOOST_FAIL("Failed to open Listening Socket");
+    }
 
     // Enable the profiling service.
     armnn::IRuntime::CreationOptions::ExternalProfilingOptions options;
@@ -154,12 +169,15 @@
     profilingService.Update();
 
     // Connect the profiling service to the mock Gatord.
-    int clientFd = mockService.BlockForOneClient();
-    if (-1 == clientFd)
+    armnnUtils::Sockets::Socket clientSocket =
+            armnnUtils::Sockets::Accept(listeningSocket, nullptr, nullptr, SOCK_CLOEXEC);
+    if (-1 == clientSocket)
     {
         BOOST_FAIL("Failed to connect client");
     }
 
+    gatordmock::GatordMockService mockService(clientSocket, registry, false);
+
     // Give the profiling service sending thread time start executing and send the stream metadata.
     while (profilingService.GetCurrentState() != profiling::ProfilingState::WaitingForAck)
     {
@@ -286,7 +304,7 @@
     mockService.WaitForReceivingThread();
     options.m_EnableProfiling = false;
     profilingService.ResetExternalProfilingOptions(options, true);
-
+    armnnUtils::Sockets::Close(listeningSocket);
     // Future tests here will add counters to the ProfilingService, increment values and examine
     // PeriodicCounterCapture data received. These are yet to be integrated.
 }