Reduce code duplication in HAL 1.2

 * Introduce a generic CallbackContext<Callback, Context> that bundles the
   completion callback with per-HAL-version execution state (see the sketch
   below), and share it between the 1.0 and 1.2 prepared models.
 * Factor output shape computation into ComputeShape() and memory pool
   commit into CommitPools().
 * Split request preparation into PrepareMemoryForInputs(),
   PrepareMemoryForOutputs() and PrepareMemoryForIO(), shared by the
   synchronous and asynchronous execution paths of ArmnnPreparedModel_1_2.
 * Template ArmnnPreparedModel_1_2::ExecuteGraph() on the callback context
   type so one implementation serves executeSynchronously(), Execute() and
   ExecuteWithDummyInputs().
 * Pass input and output tensors to ExecuteGraph() by reference and rename
   the RequestThread's Callback template parameter to CallbackContext.
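
The unifying piece is the CallbackContext aggregate. A simplified sketch of
how the HAL 1.2 synchronous path now drives the shared ExecuteGraph()
(names taken from this patch, bodies abridged):

    // Generic bundle of a completion callback plus per-version state
    // (added to ArmnnDriverImpl.hpp by this patch).
    template <typename Callback, typename Context>
    struct CallbackContext
    {
        Callback callback;
        Context  ctx;       // e.g. MeasureTiming and driverStart for HAL 1.2
    };

    // executeSynchronously() fills a context and calls ExecuteGraph() in
    // place; Execute() fills the same context and posts it to the
    // RequestThread for asynchronous execution.
    CallbackContext_1_2 cbCtx;
    cbCtx.callback = cbWrapper;               // forwards status/shapes/timing to the HIDL cb
    cbCtx.ctx.measureTimings = measureTiming;
    cbCtx.ctx.driverStart = driverStart;
    ExecuteGraph(memPools, *inputs, *outputs, cbCtx);
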
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Change-Id: Ic2e8964745a4323efb1e06d466c0699f17a70c55
diff --git a/ArmnnDriverImpl.hpp b/ArmnnDriverImpl.hpp
index c5b1778..dfaafb3 100644
--- a/ArmnnDriverImpl.hpp
+++ b/ArmnnDriverImpl.hpp
@@ -23,6 +23,13 @@
 namespace armnn_driver
 {
 
+template <typename Callback, typename Context>
+struct CallbackContext
+{
+    Callback callback;
+    Context ctx;
+};
+
 template<typename HalPolicy>
 class ArmnnDriverImpl
 {
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp
index 2cd560d..d095e41 100644
--- a/ArmnnPreparedModel.cpp
+++ b/ArmnnPreparedModel.cpp
@@ -84,7 +84,8 @@
 namespace armnn_driver
 {
 template<typename HalVersion>
-RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> ArmnnPreparedModel<HalVersion>::m_RequestThread;
+RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0>
+    ArmnnPreparedModel<HalVersion>::m_RequestThread;
 
 template<typename HalVersion>
 template <typename TensorBindingCollection>
@@ -226,7 +227,7 @@
         NotifyCallbackAndCheck(callback, errorStatus, callingFunction);
     };
 
-    ArmnnCallback_1_0 armnnCb;
+    CallbackContext_1_0 armnnCb;
     armnnCb.callback = cb;
     // post the request for asynchronous execution
     m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
@@ -237,18 +238,18 @@
 template<typename HalVersion>
 void ArmnnPreparedModel<HalVersion>::ExecuteGraph(
         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
-        std::shared_ptr<armnn::InputTensors>& pInputTensors,
-        std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
-        ArmnnCallback_1_0 cb)
+        armnn::InputTensors& inputTensors,
+        armnn::OutputTensors& outputTensors,
+        CallbackContext_1_0 cb)
 {
     ALOGV("ArmnnPreparedModel::ExecuteGraph(...)");
 
-    DumpTensorsIfRequired("Input", *pInputTensors);
+    DumpTensorsIfRequired("Input", inputTensors);
 
     // run it
     try
     {
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
+        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
         if (status != armnn::Status::Success)
         {
             ALOGW("EnqueueWorkload failed");
@@ -269,7 +270,7 @@
         return;
     }
 
-    DumpTensorsIfRequired("Output", *pOutputTensors);
+    DumpTensorsIfRequired("Output", outputTensors);
 
     // Commit output buffers.
     // Note that we update *all* pools, even if they aren't actually used as outputs -
diff --git a/ArmnnPreparedModel.hpp b/ArmnnPreparedModel.hpp
index 270a933..89f6226 100644
--- a/ArmnnPreparedModel.hpp
+++ b/ArmnnPreparedModel.hpp
@@ -24,6 +24,10 @@
     armnnExecuteCallback_1_0 callback;
 };
 
+struct ExecutionContext_1_0 {};
+
+using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;
+
 template <typename HalVersion>
 class ArmnnPreparedModel : public V1_0::IPreparedModel
 {
@@ -43,9 +47,9 @@
 
     /// execute the graph prepared from the request
     void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
-                      std::shared_ptr<armnn::InputTensors>& pInputTensors,
-                      std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
-                      ArmnnCallback_1_0 callback);
+                      armnn::InputTensors& inputTensors,
+                      armnn::OutputTensors& outputTensors,
+                      CallbackContext_1_0 callback);
 
     /// Executes this model with dummy inputs (e.g. all zeroes).
     /// \return false on failure, otherwise true
@@ -60,7 +64,7 @@
     HalModel                                                                m_Model;
     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
     // It is specific to this class, so it is declared as static here
-    static RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> m_RequestThread;
+    static RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> m_RequestThread;
     uint32_t                                                                m_RequestCount;
     const std::string&                                                      m_RequestInputsAndOutputsDumpDir;
     const bool                                                              m_GpuProfilingEnabled;
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index 9b79044..5031c5f 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -120,7 +120,7 @@
 {
 
 template<typename HalVersion>
-RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2>
+RequestThread<ArmnnPreparedModel_1_2, HalVersion, CallbackContext_1_2>
         ArmnnPreparedModel_1_2<HalVersion>::m_RequestThread;
 
 template<typename HalVersion>
@@ -215,6 +215,160 @@
     return Execute(request, measureTiming, cb);
 }
 
+OutputShape ComputeShape(const armnn::TensorInfo& info)
+{
+    OutputShape shape;
+
+    hidl_vec<uint32_t> dimensions;
+
+    armnn::TensorShape tensorShape = info.GetShape();
+    const unsigned int numDims = tensorShape.GetNumDimensions();
+    dimensions.resize(numDims);
+
+    for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
+    {
+        dimensions[outputIdx] = tensorShape[outputIdx];
+    }
+
+    shape.dimensions = dimensions;
+    shape.isSufficient = true;
+
+    return shape;
+}
+
+template<typename HalVersion>
+Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForInputs(
+    armnn::InputTensors& inputs,
+    const V1_0::Request& request,
+    const std::vector<android::nn::RunTimePoolInfo>& memPools)
+{
+    inputs.reserve(request.inputs.size());
+    for (unsigned int i = 0; i < request.inputs.size(); i++)
+    {
+        const auto& inputArg = request.inputs[i];
+
+        const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
+        const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, memPools);
+
+        if (inputTensor.GetMemoryArea() == nullptr)
+        {
+            ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
+            return V1_0::ErrorStatus::GENERAL_FAILURE;
+        }
+
+        inputs.emplace_back(i, inputTensor);
+    }
+
+    return V1_0::ErrorStatus::NONE;
+}
+
+template<typename HalVersion>
+Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForOutputs(
+    armnn::OutputTensors& outputs,
+    std::vector<OutputShape>& outputShapes,
+    const V1_0::Request& request,
+    const std::vector<android::nn::RunTimePoolInfo>& memPools)
+{
+    outputs.reserve(request.outputs.size());
+    for (unsigned int i = 0; i < request.outputs.size(); i++)
+    {
+        const auto& outputArg = request.outputs[i];
+
+        const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
+        const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, memPools);
+        if (outputTensor.GetMemoryArea() == nullptr)
+        {
+            ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
+            return V1_0::ErrorStatus::GENERAL_FAILURE;
+        }
+
+        const size_t outputSize = outputTensorInfo.GetNumBytes();
+        const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
+        outputShapes[i] = ComputeShape(outputTensorInfo); // report the shape even if the buffer is too small
+        if (bufferSize < outputSize)
+        {
+            ALOGW("ArmnnPreparedModel_1_2::Execute failed: output %u buffer is too small", i);
+            outputShapes[i].isSufficient = false;
+            return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
+        }
+        outputs.emplace_back(i, outputTensor);
+    }
+
+    return V1_0::ErrorStatus::NONE;
+}
+
+template<typename HalVersion>
+Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForIO(
+    armnn::InputTensors& inputs,
+    armnn::OutputTensors& outputs,
+    std::vector<android::nn::RunTimePoolInfo>& memPools,
+    const V1_0::Request& request,
+    CallbackAsync_1_2 callback)
+{
+    if (!setRunTimePoolInfosFromHidlMemories(&memPools, request.pools))
+    {
+        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+        return V1_0::ErrorStatus::GENERAL_FAILURE;
+    }
+
+    // add the inputs and outputs with their data
+    try
+    {
+        if (PrepareMemoryForInputs(inputs, request, memPools) != V1_0::ErrorStatus::NONE)
+        {
+            callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+            return V1_0::ErrorStatus::GENERAL_FAILURE;
+        }
+
+        std::vector<OutputShape> outputShapes(request.outputs.size());
+
+        auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
+        if (errorStatus != V1_0::ErrorStatus::NONE)
+        {
+            callback(errorStatus,
+                     outputShapes,
+                     g_NoTiming,
+                     "ArmnnPreparedModel_1_2::Execute");
+            return errorStatus;
+        }
+    }
+    catch (armnn::Exception& e)
+    {
+        ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
+        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+        return V1_0::ErrorStatus::GENERAL_FAILURE;
+    }
+    catch (std::exception& e)
+    {
+        ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
+        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+        return V1_0::ErrorStatus::GENERAL_FAILURE;
+    }
+
+    return V1_0::ErrorStatus::NONE;
+}
+
+void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools)
+{
+    if (memPools.empty())
+    {
+        return;
+    }
+    // Commit output buffers.
+    // Note that we update *all* pools, even if they aren't actually used as outputs -
+    // this is simpler and is what the CpuExecutor does.
+    for (auto& pool : memPools)
+    {
+        // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where
+        // update() has been removed and flush() added.
+#if defined(ARMNN_ANDROID_R) // Use the new Android implementation.
+        pool.flush();
+#else
+        pool.update();
+#endif
+    }
+}
+
 template<typename HalVersion>
 Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request,
                                                                       MeasureTiming measureTiming,
@@ -229,7 +383,7 @@
         return Void();
     }
 
-    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
+    TimePoint driverStart;
 
     if (measureTiming == MeasureTiming::YES)
     {
@@ -243,167 +397,210 @@
         return Void();
     }
 
-    // allocate the tensors on the heap, as they are passed to the request thread
-    auto pInputTensors = std::make_shared<armnn::InputTensors>();
-    auto pOutputTensors = std::make_shared<armnn::OutputTensors>();
+    auto cbWrapper = [cb](V1_0::ErrorStatus errorStatus,
+                          std::vector<OutputShape> outputShapes,
+                          const Timing& timing,
+                          std::string)
+        {
+            cb(errorStatus, outputShapes, timing);
+        };
 
     // map the memory pool into shared pointers
     // use a shared memory pools vector on the heap, as it is passed to the request thread
-    auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
+    auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
 
-    if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools))
+    // allocate the tensors on the heap, as they are passed to the request thread
+    auto inputs = std::make_shared<armnn::InputTensors>();
+    auto outputs = std::make_shared<armnn::OutputTensors>();
+
+    auto prepareStatus = PrepareMemoryForIO(*inputs, *outputs, *memPools, request, cbWrapper);
+    if (prepareStatus != V1_0::ErrorStatus::NONE)
     {
-        cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
-        return Void();
-    }
-    std::vector<OutputShape> outputShapes(request.outputs.size());
-
-    try
-    {
-        pInputTensors->reserve(request.inputs.size());
-        for (unsigned int i = 0; i < request.inputs.size(); i++)
-        {
-            const auto& inputArg = request.inputs[i];
-
-            const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
-            const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, *pMemPools);
-
-            if (inputTensor.GetMemoryArea() == nullptr)
-            {
-                ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
-                cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
-                return Void();
-            }
-
-            pInputTensors->emplace_back(i, inputTensor);
-        }
-        pOutputTensors->reserve(request.outputs.size());
-
-        for (unsigned int i = 0; i < request.outputs.size(); i++)
-        {
-            const auto& outputArg = request.outputs[i];
-
-            const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
-            const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, *pMemPools);
-
-            if (outputTensor.GetMemoryArea() == nullptr)
-            {
-                ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
-                cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
-                return Void();
-            }
-            const size_t outputSize = outputTensorInfo.GetNumBytes();
-            const size_t bufferSize = pMemPools->at(outputArg.location.poolIndex).getHidlMemory().size();
-
-            hidl_vec<uint32_t> dimensions;
-
-            armnn::TensorShape tensorShape = outputTensorInfo.GetShape();
-            const unsigned int numDims = tensorShape.GetNumDimensions();
-            dimensions.resize(numDims);
-
-            for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
-            {
-                dimensions[outputIdx] = tensorShape[outputIdx];
-            }
-            outputShapes[i].dimensions = dimensions;
-            outputShapes[i].isSufficient = bufferSize >= outputSize;
-
-            if (bufferSize < outputSize)
-            {
-                ALOGW("ArmnnPreparedModel_1_2::Execute failed");
-                cb(V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, outputShapes, g_NoTiming);
-                return Void();
-            }
-
-            pOutputTensors->emplace_back(i, outputTensor);
-        }
-    }
-    catch (armnn::Exception& e)
-    {
-        ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
-        cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
-        return Void();
-    }
-    catch (std::exception& e)
-    {
-        ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
-        cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
         return Void();
     }
 
     ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution");
 
-    DumpTensorsIfRequired("Input", *pInputTensors);
+    CallbackContext_1_2 cbCtx;
+    cbCtx.callback = cbWrapper;
+    cbCtx.ctx.measureTimings = measureTiming;
+    cbCtx.ctx.driverStart = driverStart;
+    ExecuteGraph(memPools, *inputs, *outputs, cbCtx);
+
+    return Void();
+}
+
+template<typename HalVersion>
+template<typename CallbackContext>
+bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        armnn::InputTensors& inputTensors,
+        armnn::OutputTensors& outputTensors,
+        CallbackContext cb)
+{
+    ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph(...)");
+
+    TimePoint driverEnd, deviceStart, deviceEnd;
+
+    DumpTensorsIfRequired("Input", inputTensors);
+
+    std::vector<OutputShape> outputShapes(outputTensors.size());
+    for (unsigned int i = 0; i < outputTensors.size(); i++)
+    {
+        std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
+        const armnn::Tensor outputTensor = outputTensorPair.second;
+        const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
+
+        outputShapes[i] = ComputeShape(outputTensorInfo);
+    }
+
     // run it
     try
     {
-        if (measureTiming == MeasureTiming::YES)
+        if (cb.ctx.measureTimings == MeasureTiming::YES)
         {
             deviceStart = Now();
         }
 
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
+        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
 
-        if (measureTiming == MeasureTiming::YES)
+        if (cb.ctx.measureTimings == MeasureTiming::YES)
         {
             deviceEnd = Now();
         }
-
         if (status != armnn::Status::Success)
         {
             ALOGW("EnqueueWorkload failed");
-            cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
-            return Void();
+            cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming,
+                    "ArmnnPreparedModel_1_2::ExecuteGraph");
+            return false;
         }
     }
     catch (armnn::Exception& e)
     {
-        ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what());
-        cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
-        return Void();
+        ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what());
+        cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
+        return false;
     }
     catch (std::exception& e)
     {
         ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
-        cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
-        return Void();
+        cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
+        return false;
     }
 
-    DumpTensorsIfRequired("Output", *pOutputTensors);
+    CommitPools(*pMemPools);
 
-    // Commit output buffers.
-    // Note that we update *all* pools, even if they aren't actually used as outputs -
-    // this is simpler and is what the CpuExecutor does.
-    for (android::nn::RunTimePoolInfo& pool : *pMemPools)
-    {
-        // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where
-        // update() has been removed and flush() added.
-        #if defined(ARMNN_ANDROID_R) // Use the new Android implementation.
-            pool.flush();
-        #else
-            pool.update();
-        #endif
-    }
+    DumpTensorsIfRequired("Output", outputTensors);
 
-    ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() after Execution");
-
-    if (measureTiming == MeasureTiming::YES)
+    if (cb.ctx.measureTimings == MeasureTiming::YES)
     {
         driverEnd = Now();
         Timing timing;
         timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
-        timing.timeInDriver = MicrosecondsDuration(driverEnd, driverStart);
-        ALOGV("ArmnnPreparedModel_1_2::executeSynchronously timing Device = %lu Driver = %lu", timing.timeOnDevice,
-                timing.timeInDriver);
-        cb(V1_0::ErrorStatus::NONE, outputShapes, timing);
+        timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.ctx.driverStart);
+        ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
+              timing.timeInDriver);
+        cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph");
+    } else {
+        cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
     }
-    else
-    {
-        cb(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming);
-    }
-    return Void();
+
+    return true;
 }
 
+template<typename HalVersion>
+bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs()
+{
+    std::vector<std::vector<char>> storage;
+    armnn::InputTensors inputTensors;
+    for (unsigned int i = 0; i < m_Model.inputIndexes.size(); i++)
+    {
+        const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
+        storage.emplace_back(inputTensorInfo.GetNumBytes());
+        const armnn::ConstTensor inputTensor(inputTensorInfo, storage.back().data());
+
+        inputTensors.emplace_back(i, inputTensor);
+    }
+
+    armnn::OutputTensors outputTensors;
+    for (unsigned int i = 0; i < m_Model.outputIndexes.size(); i++)
+    {
+        const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
+        storage.emplace_back(outputTensorInfo.GetNumBytes());
+        const armnn::Tensor outputTensor(outputTensorInfo, storage.back().data());
+
+        outputTensors.emplace_back(i, outputTensor);
+    }
+
+    auto nullCallback = [](V1_0::ErrorStatus, std::vector<OutputShape>, const Timing&, std::string) {};
+    CallbackContext_1_2 callbackContext;
+    callbackContext.callback = nullCallback;
+    callbackContext.ctx.measureTimings = MeasureTiming::NO;
+    auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
+    return ExecuteGraph(memPools,
+                        inputTensors,
+                        outputTensors,
+                        callbackContext);
+}
+
+template<typename HalVersion>
+Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_0::Request& request,
+                                                                       MeasureTiming measureTiming,
+                                                                       CallbackAsync_1_2 callback)
+{
+    ExecutionContext_1_2 ctx;
+    if (measureTiming == MeasureTiming::YES)
+    {
+        ctx.measureTimings = measureTiming;
+        ctx.driverStart = Now();
+    }
+
+    ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str());
+    m_RequestCount++;
+
+    if (!android::nn::validateRequest(request, m_Model))
+    {
+        callback(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+        return V1_0::ErrorStatus::INVALID_ARGUMENT;
+    }
+
+    if (!m_RequestInputsAndOutputsDumpDir.empty())
+    {
+        ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&callback));
+    }
+
+    // map the memory pool into shared pointers
+    // use a shared memory pools vector on the heap, as it is passed to the request thread
+    auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
+
+    // allocate the tensors on the heap, as they are passed to the request thread
+    auto inputTensors = std::make_shared<armnn::InputTensors>();
+    auto outputTensors = std::make_shared<armnn::OutputTensors>();
+
+    auto prepareStatus = PrepareMemoryForIO(*inputTensors, *outputTensors, *memPools, request, callback);
+    switch (prepareStatus)
+    {
+        case V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
+            return V1_0::ErrorStatus::NONE;
+        case V1_0::ErrorStatus::GENERAL_FAILURE:
+            return V1_0::ErrorStatus::GENERAL_FAILURE;
+        default:
+        {}
+    }
+
+    ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
+
+    // post the request for asynchronous execution
+    CallbackContext_1_2 cb;
+    cb.callback = callback;
+    cb.ctx = ctx;
+    m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
+    ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
+    return V1_0::ErrorStatus::NONE;
+}
+
+
 /// This class is strongly inspired by the default implementation in Android named DefaultBurstExecutorWithCache.
 /// The original code is licensed under Apache-2.0 and can be found at the following link:
 /// https://android.googlesource.com/platform/frameworks/
@@ -431,16 +628,16 @@
     }
 
     std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
-            const V1_0::Request& request, const std::vector<int32_t>& slots,
-            MeasureTiming measure) override
+        const V1_0::Request& request, const std::vector<int32_t>& slots,
+        MeasureTiming measure) override
     {
         ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache::execute");
         hidl_vec<hidl_memory> pools(slots.size());
 
         std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot)
-        {
-            return m_MemoryCache[slot];
-        });
+            {
+                return m_MemoryCache[slot];
+            });
 
         V1_0::Request fullRequest = request;
         fullRequest.pools = std::move(pools);
@@ -452,11 +649,11 @@
         auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](V1_0::ErrorStatus status,
                                                                             const hidl_vec<OutputShape>& outputShapes,
                                                                             const Timing& timing)
-        {
-            returnedStatus = status;
-            returnedOutputShapes = outputShapes;
-            returnedTiming = timing;
-        };
+            {
+                returnedStatus = status;
+                returnedOutputShapes = outputShapes;
+                returnedTiming = timing;
+            };
 
         // Execute
         ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache executing");
@@ -474,17 +671,16 @@
     std::map<int, hidl_memory> m_MemoryCache;
 };
 
-
 template<typename HalVersion>
 Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
-        const sp<V1_2::IBurstCallback>& callback,
-        const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
-        const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
-        V1_2::IPreparedModel::configureExecutionBurst_cb cb)
+    const sp<V1_2::IBurstCallback>& callback,
+    const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
+    const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
+    V1_2::IPreparedModel::configureExecutionBurst_cb cb)
 {
     ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst");
     const std::shared_ptr<ArmnnBurstExecutorWithCache> executorWithCache =
-            std::make_shared<ArmnnBurstExecutorWithCache>(this);
+        std::make_shared<ArmnnBurstExecutorWithCache>(this);
     const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback,
                                                                        requestChannel,
                                                                        resultChannel,
@@ -501,282 +697,15 @@
     return Void();
 }
 
-template<typename HalVersion>
-void ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
-        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
-        std::shared_ptr<armnn::InputTensors>& pInputTensors,
-        std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
-        ArmnnCallback_1_2 cb)
-{
-    ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph(...)");
 
-    TimePoint driverEnd, deviceStart, deviceEnd;
-
-    DumpTensorsIfRequired("Input", *pInputTensors);
-
-    std::vector<std::pair<int, armnn::Tensor> > outputTensors = *pOutputTensors.get();
-    std::vector<OutputShape> outputShapes(outputTensors.size());
-
-    for (unsigned int i = 0; i < outputTensors.size(); i++)
-    {
-        std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
-        const armnn::Tensor outputTensor = outputTensorPair.second;
-        const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
-
-        hidl_vec<uint32_t> dimensions;
-
-        armnn::TensorShape tensorShape = outputTensorInfo.GetShape();
-        const unsigned int numDims = tensorShape.GetNumDimensions();
-        dimensions.resize(numDims);
-
-        for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
-        {
-            dimensions[outputIdx] = tensorShape[outputIdx];
-        }
-        outputShapes[i].dimensions = dimensions;
-        outputShapes[i].isSufficient = true;
-    }
-
-    // run it
-    try
-    {
-        if (cb.measureTiming == MeasureTiming::YES)
-        {
-            deviceStart = Now();
-        }
-
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
-
-        if (cb.measureTiming == MeasureTiming::YES)
-        {
-            deviceEnd = Now();
-        }
-        if (status != armnn::Status::Success)
-        {
-            ALOGW("EnqueueWorkload failed");
-            cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming,
-                    "ArmnnPreparedModel_1_2::ExecuteGraph");
-            return;
-        }
-    }
-    catch (armnn::Exception& e)
-    {
-        ALOGW("armnn:Exception caught from EnqueueWorkload: %s", e.what());
-        cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
-        return;
-    }
-    catch (std::exception& e)
-    {
-        ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
-        cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
-        return;
-    }
-
-    DumpTensorsIfRequired("Output", *pOutputTensors);
-
-    // Commit output buffers.
-    // Note that we update *all* pools, even if they aren't actually used as outputs -
-    // this is simpler and is what the CpuExecutor does.
-    for (android::nn::RunTimePoolInfo& pool : *pMemPools)
-    {
-        // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where
-        // update() has been removed and flush() added.
-        #if defined(ARMNN_ANDROID_R) // Use the new Android implementation.
-            pool.flush();
-        #else
-            pool.update();
-        #endif
-    }
-
-    if (cb.measureTiming == MeasureTiming::YES)
-    {
-        driverEnd = Now();
-        Timing timing;
-        timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
-        timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.driverStart);
-        cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ExecuteGraph");
-    } else {
-        cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ExecuteGraph");
-    }
-}
-
-template<typename HalVersion>
-bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs()
-{
-    std::vector<std::vector<char>> storage;
-    armnn::InputTensors inputTensors;
-    for (unsigned int i = 0; i < m_Model.inputIndexes.size(); i++)
-    {
-        const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
-        storage.emplace_back(inputTensorInfo.GetNumBytes());
-        const armnn::ConstTensor inputTensor(inputTensorInfo, storage.back().data());
-
-        inputTensors.emplace_back(i, inputTensor);
-    }
-
-    armnn::OutputTensors outputTensors;
-    for (unsigned int i = 0; i < m_Model.outputIndexes.size(); i++)
-    {
-        const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
-        storage.emplace_back(outputTensorInfo.GetNumBytes());
-        const armnn::Tensor outputTensor(outputTensorInfo, storage.back().data());
-
-        outputTensors.emplace_back(i, outputTensor);
-    }
-
-    try
-    {
-        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
-        if (status != armnn::Status::Success)
-        {
-            ALOGW("ExecuteWithDummyInputs: EnqueueWorkload failed");
-            return false;
-        }
-    }
-    catch (armnn::Exception& e)
-    {
-        ALOGW("ExecuteWithDummyInputs: armnn::Exception caught from EnqueueWorkload: %s", e.what());
-        return false;
-    }
-    catch (std::exception& e)
-    {
-        ALOGE("ExecuteWithDummyInputs: std::exception caught from EnqueueWorkload: %s", e.what());
-        return false;
-    }
-    return true;
-}
-
-template<typename HalVersion>
-Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_0::Request& request,
-                                                                       MeasureTiming measureTiming,
-                                                                       armnnExecuteCallback_1_2 callback)
-{
-    TimePoint driverStart;
-
-    if (measureTiming == MeasureTiming::YES)
-    {
-        driverStart = Now();
-    }
-
-    ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str());
-    m_RequestCount++;
-
-    if (!android::nn::validateRequest(request, m_Model))
-    {
-        callback(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
-        return V1_0::ErrorStatus::INVALID_ARGUMENT;
-    }
-
-    if (!m_RequestInputsAndOutputsDumpDir.empty())
-    {
-        ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&callback));
-    }
-
-    // allocate the tensors on the heap, as they are passed to the request thread
-    auto pInputTensors = std::make_shared<armnn::InputTensors>();
-    auto pOutputTensors = std::make_shared<armnn::OutputTensors>();
-
-    // map the memory pool into shared pointers
-    // use a shared memory pools vector on the heap, as it is passed to the request thread
-    auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
-
-    if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools))
-    {
-        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
-        return V1_0::ErrorStatus::GENERAL_FAILURE;
-    }
-
-    // add the inputs and outputs with their data
-    try
-    {
-        pInputTensors->reserve(request.inputs.size());
-        for (unsigned int i = 0; i < request.inputs.size(); i++)
-        {
-            const auto& inputArg = request.inputs[i];
-
-            const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
-            const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, *pMemPools);
-
-            if (inputTensor.GetMemoryArea() == nullptr)
-            {
-                ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
-                callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
-                return V1_0::ErrorStatus::GENERAL_FAILURE;
-            }
-
-            pInputTensors->emplace_back(i, inputTensor);
-        }
-
-        pOutputTensors->reserve(request.outputs.size());
-        std::vector<OutputShape> outputShapes(request.outputs.size());
-
-        for (unsigned int i = 0; i < request.outputs.size(); i++)
-        {
-            const auto& outputArg = request.outputs[i];
-
-            const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
-            const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, *pMemPools);
-            if (outputTensor.GetMemoryArea() == nullptr)
-            {
-                ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
-                callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
-                return V1_0::ErrorStatus::GENERAL_FAILURE;
-            }
-
-            const size_t outputSize = outputTensorInfo.GetNumBytes();
-            const size_t bufferSize = pMemPools->at(outputArg.location.poolIndex).getHidlMemory().size();
-            pOutputTensors->emplace_back(i, outputTensor);
-
-            hidl_vec<uint32_t> dimensions;
-
-            armnn::TensorShape tensorShape = outputTensorInfo.GetShape();
-            const unsigned int numDims = tensorShape.GetNumDimensions();
-            dimensions.resize(numDims);
-
-            for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
-            {
-                dimensions[outputIdx] = tensorShape[outputIdx];
-            }
-            outputShapes[i].dimensions = dimensions;
-            outputShapes[i].isSufficient = bufferSize >= outputSize;
-
-            if (bufferSize < outputSize)
-            {
-                ALOGW("ArmnnPreparedModel_1_2::Execute failed");
-                callback(V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
-                         outputShapes,
-                         g_NoTiming,
-                         "ArmnnPreparedModel_1_2::Execute");
-                return V1_0::ErrorStatus::NONE;
-            }
-        }
-    }
-    catch (armnn::Exception& e)
-    {
-        ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
-        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
-        return V1_0::ErrorStatus::GENERAL_FAILURE;
-    }
-    catch (std::exception& e)
-    {
-        ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
-        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
-        return V1_0::ErrorStatus::GENERAL_FAILURE;
-    }
-
-    ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
-    // post the request for asynchronous execution
-    ArmnnCallback_1_2 armnnCb;
-    armnnCb.callback = callback;
-    armnnCb.measureTiming = measureTiming;
-    armnnCb.driverStart = driverStart;
-    m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
-    ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
-    return V1_0::ErrorStatus::NONE;
-}
 
 #ifdef ARMNN_ANDROID_NN_V1_2
 template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>;
+template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>(
+        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+        armnn::InputTensors& inputTensors,
+        armnn::OutputTensors& outputTensors,
+        CallbackContext_1_2 cb);
 #endif
 
 } // namespace armnn_driver
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp
index f609ef7..e68614a 100644
--- a/ArmnnPreparedModel_1_2.hpp
+++ b/ArmnnPreparedModel_1_2.hpp
@@ -19,18 +19,21 @@
 namespace armnn_driver
 {
 
-typedef std::function<void(::android::hardware::neuralnetworks::V1_0::ErrorStatus status,
-        std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
-        const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
-        std::string callingFunction)> armnnExecuteCallback_1_2;
+using CallbackAsync_1_2 = std::function<
+                                void(V1_0::ErrorStatus errorStatus,
+                                     std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
+                                     const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
+                                     std::string callingFunction)>;
 
-struct ArmnnCallback_1_2
+struct ExecutionContext_1_2
 {
-    armnnExecuteCallback_1_2 callback;
+    ::android::hardware::neuralnetworks::V1_2::MeasureTiming    measureTimings =
+        ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
     TimePoint driverStart;
-    MeasureTiming measureTiming;
 };
 
+using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
+
 template <typename HalVersion>
 class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
 {
@@ -62,19 +65,38 @@
             configureExecutionBurst_cb cb) override;
 
     /// execute the graph prepared from the request
-    void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
-                      std::shared_ptr<armnn::InputTensors>& pInputTensors,
-                      std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
-                      ArmnnCallback_1_2 callbackDescriptor);
+    template<typename CallbackContext>
+    bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+                      armnn::InputTensors& inputTensors,
+                      armnn::OutputTensors& outputTensors,
+                      CallbackContext callback);
 
     /// Executes this model with dummy inputs (e.g. all zeroes).
     /// \return false on failure, otherwise true
     bool ExecuteWithDummyInputs();
 
 private:
-    Return <V1_0::ErrorStatus> Execute(const V1_0::Request& request,
-                                       MeasureTiming measureTiming,
-                                       armnnExecuteCallback_1_2 callback);
+    Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
+                                      MeasureTiming measureTiming,
+                                      CallbackAsync_1_2 callback);
+
+    Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
+            armnn::InputTensors& inputs,
+            const V1_0::Request& request,
+            const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+    Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
+            armnn::OutputTensors& outputs,
+            std::vector<OutputShape>& outputShapes,
+            const V1_0::Request& request,
+            const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+    Return<V1_0::ErrorStatus> PrepareMemoryForIO(
+            armnn::InputTensors& inputs,
+            armnn::OutputTensors& outputs,
+            std::vector<android::nn::RunTimePoolInfo>& memPools,
+            const V1_0::Request& request,
+            CallbackAsync_1_2 callback);
 
     template <typename TensorBindingCollection>
     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
@@ -84,7 +106,9 @@
     V1_2::Model                                                                 m_Model;
     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
     // It is specific to this class, so it is declared as static here
-    static RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> m_RequestThread;
+    static RequestThread<ArmnnPreparedModel_1_2,
+                         HalVersion,
+                         CallbackContext_1_2>                                   m_RequestThread;
     uint32_t                                                                    m_RequestCount;
     const std::string&                                                          m_RequestInputsAndOutputsDumpDir;
     const bool                                                                  m_GpuProfilingEnabled;
diff --git a/RequestThread.cpp b/RequestThread.cpp
index 052c5c1..22a3ac3 100644
--- a/RequestThread.cpp
+++ b/RequestThread.cpp
@@ -21,15 +21,15 @@
 namespace armnn_driver
 {
 
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-RequestThread<PreparedModel, HalVersion, Callback>::RequestThread()
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+RequestThread<PreparedModel, HalVersion, CallbackContext>::RequestThread()
 {
     ALOGV("RequestThread::RequestThread()");
     m_Thread = std::make_unique<std::thread>(&RequestThread::Process, this);
 }
 
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-RequestThread<PreparedModel, HalVersion, Callback>::~RequestThread()
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+RequestThread<PreparedModel, HalVersion, CallbackContext>::~RequestThread()
 {
     ALOGV("RequestThread::~RequestThread()");
 
@@ -54,25 +54,25 @@
     catch (const std::exception&) { } // Swallow any exception.
 }
 
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(PreparedModel<HalVersion>* model,
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread<PreparedModel, HalVersion, CallbackContext>::PostMsg(PreparedModel<HalVersion>* model,
         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
         std::shared_ptr<armnn::InputTensors>& inputTensors,
         std::shared_ptr<armnn::OutputTensors>& outputTensors,
-        Callback callback)
+        CallbackContext callbackContext)
 {
     ALOGV("RequestThread::PostMsg(...)");
     auto data = std::make_shared<AsyncExecuteData>(model,
                                                    memPools,
                                                    inputTensors,
                                                    outputTensors,
-                                                   callback);
+                                                   callbackContext);
     auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::REQUEST, data);
     PostMsg(pMsg);
 }
 
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg)
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread<PreparedModel, HalVersion, CallbackContext>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg)
 {
     ALOGV("RequestThread::PostMsg(pMsg)");
     // Add a message to the queue and notify the request thread
@@ -81,8 +81,8 @@
     m_Cv.notify_one();
 }
 
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-void RequestThread<PreparedModel, HalVersion, Callback>::Process()
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread<PreparedModel, HalVersion, CallbackContext>::Process()
 {
     ALOGV("RequestThread::Process()");
     while (true)
@@ -109,9 +109,9 @@
                 // invoke the asynchronous execution method
                 PreparedModel<HalVersion>* model = pMsg->data->m_Model;
                 model->ExecuteGraph(pMsg->data->m_MemPools,
-                                    pMsg->data->m_InputTensors,
-                                    pMsg->data->m_OutputTensors,
-                                    pMsg->data->m_Callback);
+                                    *(pMsg->data->m_InputTensors),
+                                    *(pMsg->data->m_OutputTensors),
+                                    pMsg->data->m_CallbackContext);
                 break;
             }
 
@@ -139,16 +139,16 @@
 /// Class template specializations
 ///
 
-template class RequestThread<ArmnnPreparedModel, hal_1_0::HalPolicy, ArmnnCallback_1_0>;
+template class RequestThread<ArmnnPreparedModel, hal_1_0::HalPolicy, CallbackContext_1_0>;
 
 #ifdef ARMNN_ANDROID_NN_V1_1
-template class RequestThread<armnn_driver::ArmnnPreparedModel, hal_1_1::HalPolicy, ArmnnCallback_1_0>;
+template class RequestThread<armnn_driver::ArmnnPreparedModel, hal_1_1::HalPolicy, CallbackContext_1_0>;
 #endif
 
 #ifdef ARMNN_ANDROID_NN_V1_2
-template class RequestThread<ArmnnPreparedModel, hal_1_1::HalPolicy, ArmnnCallback_1_0>;
-template class RequestThread<ArmnnPreparedModel, hal_1_2::HalPolicy, ArmnnCallback_1_0>;
-template class RequestThread<ArmnnPreparedModel_1_2, hal_1_2::HalPolicy, ArmnnCallback_1_2>;
+template class RequestThread<ArmnnPreparedModel, hal_1_1::HalPolicy, CallbackContext_1_0>;
+template class RequestThread<ArmnnPreparedModel, hal_1_2::HalPolicy, CallbackContext_1_0>;
+template class RequestThread<ArmnnPreparedModel_1_2, hal_1_2::HalPolicy, CallbackContext_1_2>;
 #endif
 
 } // namespace armnn_driver
diff --git a/RequestThread.hpp b/RequestThread.hpp
index 253d104..79f309a 100644
--- a/RequestThread.hpp
+++ b/RequestThread.hpp
@@ -21,7 +21,7 @@
 using TimePoint = std::chrono::steady_clock::time_point;
 static const TimePoint g_Min = std::chrono::steady_clock::time_point::min();
 
-template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
+template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
 class RequestThread
 {
 public:
@@ -41,7 +41,7 @@
                  std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
                  std::shared_ptr<armnn::InputTensors>& inputTensors,
                  std::shared_ptr<armnn::OutputTensors>& outputTensors,
-                 Callback callback);
+                 CallbackContext callbackContext);
 
 private:
     RequestThread(const RequestThread&) = delete;
@@ -54,12 +54,12 @@
                          std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
                          std::shared_ptr<armnn::InputTensors>& inputTensors,
                          std::shared_ptr<armnn::OutputTensors>& outputTensors,
-                         Callback callback)
+                         CallbackContext callbackContext)
             : m_Model(model)
             , m_MemPools(memPools)
             , m_InputTensors(inputTensors)
             , m_OutputTensors(outputTensors)
-            , m_Callback(callback)
+            , m_CallbackContext(callbackContext)
         {
         }
 
@@ -67,9 +67,8 @@
         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
         std::shared_ptr<armnn::InputTensors> m_InputTensors;
         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
-        Callback m_Callback;
+        CallbackContext m_CallbackContext;
     };
-
     enum class ThreadMsgType
     {
         EXIT,                   // exit the thread