IVGCVSW-2997 Refactor reference LSTM workload

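Replace the Float32-only RefLstmFloat32Workload with a type-agnostic
RefLstmWorkload built on the Decoder/Encoder iterator abstractions.
The LSTM helper functions move into LstmUtils.hpp and now operate on
iterators, the raw-pointer Activation overload is removed, Encoder
gains a Get() method for read-modify-write accumulation, and data
type validation is added to LstmQueueDescriptor::Validate() and
RefLayerSupport::IsLstmSupported().
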
Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Change-Id: I6883f878d9f701a55153292769d2fc0530d2529e
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index ca9d7d9..61e0d40 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -860,6 +860,18 @@
 {
     ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "LstmQueueDescriptor", 2, "input");
     ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "LstmQueueDescriptor", 2, "output");
+
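+    // Lstm is currently supported for Float32 only.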
+    std::vector<DataType> supportedTypes = {
+        DataType::Float32,
+    };
+
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+                      supportedTypes,
+                      "LstmQueueDescriptor");
+
+    ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+                      supportedTypes,
+                      "LstmQueueDescriptor");
 }
 
 void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 1b1f0ce..67c13c3 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -594,12 +594,6 @@
                                       const TensorInfo* cellToOutputWeights,
                                       Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(outputStateIn);
-    ignore_unused(cellStateIn);
-    ignore_unused(scratchBuffer);
-    ignore_unused(outputStateOut);
-    ignore_unused(cellStateOut);
-    ignore_unused(output);
     ignore_unused(descriptor);
     ignore_unused(inputToForgetWeights);
     ignore_unused(inputToCellWeights);
@@ -618,10 +612,35 @@
     ignore_unused(projectionBias);
     ignore_unused(cellToForgetWeights);
     ignore_unused(cellToOutputWeights);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &FalseFuncU8<>);
+
+    bool supported = true;
+
+    std::array<DataType,1> supportedTypes = {
+        DataType::Float32
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference Lstm: input is not a supported type.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
+                                  "Reference Lstm: input and outputStateIn types are mismatched.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
+                                  "Reference Lstm: input and cellStateIn types are mismatched.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
+                                  "Reference Lstm: input and scratchBuffer types are mismatched.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
+                                  "Reference Lstm: input and outputStateOut types are mismatched.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
+                                  "Reference Lstm: input and cellStateOut types are mismatched.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference Lstm: input and output types are mismatched.");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 8887bb7..6603aaf 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -274,7 +274,7 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
     const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
+    return std::make_unique<RefLstmWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 06459ed..5034c0f 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -47,7 +47,7 @@
         workloads/RefFullyConnectedUint8Workload.cpp \
         workloads/RefGatherWorkload.cpp \
         workloads/RefL2NormalizationFloat32Workload.cpp \
-        workloads/RefLstmFloat32Workload.cpp \
+        workloads/RefLstmWorkload.cpp \
         workloads/RefMeanFloat32Workload.cpp \
         workloads/RefMeanUint8Workload.cpp \
         workloads/RefMergerFloat32Workload.cpp \
diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp
index 760c9a0..2b0c84e 100644
--- a/src/backends/reference/workloads/Activation.cpp
+++ b/src/backends/reference/workloads/Activation.cpp
@@ -88,26 +88,16 @@
                 float a,
                 float b)
 {
-    for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
+    unsigned int numElements = tensorInfo.GetNumElements();
+
+    for (unsigned int i = 0; i < numElements; i++)
     {
         out.Set(Activation(in.Get(), function, a, b));
-
         ++in;
         ++out;
     }
-}
-
-void Activation(const float* in,
-               float* out,
-               const TensorInfo& tensorInfo,
-               ActivationFunction function,
-               float a,
-               float b)
-{
-    for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
-    {
-        out[i] = Activation(in[i], function, a, b);
-    }
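+    // Rewind the iterators to their starting positions so the caller can reuse them.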
+    in -= numElements;
+    out -= numElements;
 }
 
 } //namespace armnn
diff --git a/src/backends/reference/workloads/Activation.hpp b/src/backends/reference/workloads/Activation.hpp
index ffe3c5f..b7fd50c 100644
--- a/src/backends/reference/workloads/Activation.hpp
+++ b/src/backends/reference/workloads/Activation.hpp
@@ -22,11 +22,4 @@
                 float a,
                 float b);
 
-// This is still used by Reference LSTM implementation
-void Activation(const float* in,
-                float* out,
-                const TensorInfo& tensorInfo,
-                ActivationFunction function,
-                float a,
-                float b);
 } //namespace armnn
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 3439e41..97af95a 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -29,8 +29,6 @@
 class Decoder : public BaseIterator
 {
 public:
-    using InterfaceType = IType;
-
     Decoder() {}
 
     virtual ~Decoder() {}
@@ -42,13 +40,13 @@
 class Encoder : public BaseIterator
 {
 public:
-    using InterfaceType = IType;
-
     Encoder() {}
 
     virtual ~Encoder() {}
 
     virtual void Set(IType right) = 0;
+
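+    // Allows the encoder's current value to be read back, enabling
+    // read-modify-write updates such as Set(Get() + x).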
+    virtual IType Get() const = 0;
 };
 
 template<typename T, typename Base>
@@ -77,6 +75,7 @@
         return *this;
     }
 
+protected:
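+    // Exposed to derived iterators, whose Get()/Set() overrides access m_Iterator directly.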
     T* m_Iterator;
 };
 
@@ -135,6 +134,11 @@
         *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
     }
 
+    float Get() const override
+    {
+        return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+    }
+
 private:
     const float m_Scale;
     const int32_t m_Offset;
@@ -151,6 +155,11 @@
         *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
     }
 
+    float Get() const override
+    {
+        return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+    }
+
 private:
     const float m_Scale;
     const int32_t m_Offset;
@@ -166,6 +175,11 @@
     {
         *m_Iterator = right;
     }
+
+    float Get() const override
+    {
+        return *m_Iterator;
+    }
 };
 
 class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
@@ -178,7 +192,11 @@
     {
         *m_Iterator = right;
     }
+
+    bool Get() const override
+    {
+        return *m_Iterator;
+    }
 };
 
-
-} //namespace armnn
\ No newline at end of file
+} //namespace armnn
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 596c099..b1cdef9 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -26,6 +26,7 @@
     FullyConnected.hpp
     Gather.cpp
     Gather.hpp
+    LstmUtils.hpp
     Maximum.hpp
     Merger.hpp
     Merger.cpp
@@ -80,8 +81,8 @@
     RefGatherWorkload.hpp
     RefL2NormalizationFloat32Workload.cpp
     RefL2NormalizationFloat32Workload.hpp
-    RefLstmFloat32Workload.cpp
-    RefLstmFloat32Workload.hpp
+    RefLstmWorkload.cpp
+    RefLstmWorkload.hpp
     RefMergerFloat32Workload.cpp
     RefMergerFloat32Workload.hpp
     RefMergerUint8Workload.cpp
diff --git a/src/backends/reference/workloads/LstmUtils.hpp b/src/backends/reference/workloads/LstmUtils.hpp
new file mode 100644
index 0000000..db02a84
--- /dev/null
+++ b/src/backends/reference/workloads/LstmUtils.hpp
@@ -0,0 +1,218 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+#include <backendsCommon/CpuTensorHandle.hpp>
+
+namespace
+{
+
+// Helper functions ported from the Android code base
+// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+
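+// Note: each helper advances the decoders/encoders it is given, then rewinds
+// them to their starting positions before returning.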
+void MatrixBatchVectorMultiplyAccumulate(armnn::Decoder<float>& matrix,
+                                         uint32_t mRows,
+                                         uint32_t mCols,
+                                         armnn::Decoder<float>& vector,
+                                         uint32_t nBatch,
+                                         armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t r = 0; r < mRows; r++)
+        {
+            vector += b * mCols;
+            for (uint32_t c = 0; c < mCols; c++)
+            {
+                outResult.Set(outResult.Get() + matrix.Get() * vector.Get());
+                ++matrix;
+                ++vector;
+            }
+            outResult += 1;
+            vector -= (b+1) * mCols;
+        }
+        matrix -= (mRows * mCols);
+    }
+    outResult -= (mRows * nBatch);
+}
+
+void VectorBatchVectorAssign(armnn::Decoder<float>& vector,
+                             uint32_t vSize,
+                             uint32_t nBatch,
+                             armnn::Encoder<float>& outBatchVector)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outBatchVector.Set(vector.Get());
+            ++outBatchVector;
+            ++vector;
+        }
+        vector -= vSize;
+    }
+    outBatchVector -= (nBatch * vSize);
+}
+
+void VectorBatchVectorCwiseProductAccumulate(armnn::Decoder<float>& vector,
+                                             uint32_t vSize,
+                                             armnn::Decoder<float>& batchVector,
+                                             uint32_t nBatch,
+                                             armnn::Encoder<float>& outResult)
+{
+    for (uint32_t b = 0; b < nBatch; b++)
+    {
+        for (uint32_t v = 0; v < vSize; v++)
+        {
+            outResult.Set(outResult.Get() + vector.Get() * batchVector.Get());
+            ++outResult;
+            ++vector;
+            ++batchVector;
+        }
+        vector -= vSize;
+    }
+    batchVector -= vSize * nBatch;
+    outResult -= vSize * nBatch;
+}
+
+void Sub1Vector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                armnn::Encoder<float>& result)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        result.Set(1.0f - vector.Get());
+        ++vector;
+        ++result;
+    }
+    vector -= vSize;
+    result -= vSize;
+}
+
+void VectorVectorCwiseProduct(armnn::Decoder<float>& vector1,
+                              armnn::Decoder<float>& vector2,
+                              uint32_t vSize,
+                              armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(vector1.Get() * vector2.Get());
+        ++outResult;
+        ++vector1;
+        ++vector2;
+    }
+    outResult -= vSize;
+    vector1 -= vSize;
+    vector2 -= vSize;
+}
+
+void VectorVectorCwiseProductAccumulate(armnn::Decoder<float>& vector1,
+                                        armnn::Decoder<float>& vector2,
+                                        uint32_t vSize,
+                                        armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(outResult.Get() + vector1.Get() * vector2.Get());
+        ++outResult;
+        ++vector1;
+        ++vector2;
+    }
+    outResult -= vSize;
+    vector1 -= vSize;
+    vector2 -= vSize;
+}
+
+float Clip(float f,
+           float absLimit)
+{
+    float result = (absLimit < f) ? absLimit : f;
+    result = (-absLimit > result) ? -absLimit : result;
+    return result;
+}
+
+void ClipVector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                float absLimit,
+                armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(Clip(vector.Get(), absLimit));
+        ++vector;
+        ++outResult;
+    }
+    vector -= vSize;
+    outResult -= vSize;
+}
+
+void CopyVector(armnn::Decoder<float>& vector,
+                uint32_t vSize,
+                armnn::Encoder<float>& outResult)
+{
+    for (uint32_t v = 0; v < vSize; v++)
+    {
+        outResult.Set(vector.Get());
+        ++outResult;
+        ++vector;
+    }
+    outResult -= vSize;
+    vector -= vSize;
+}
+
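+// Maps the LSTM descriptor's activation code to an armnn ActivationFunction and its a/b parameters.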
+void SetActivationParameters(uint32_t activation,
+                             armnn::ActivationFunction& outArmnnActivation,
+                             float& outA,
+                             float& outB)
+{
+    switch (activation)
+    {
+    case 0: // None
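+        // Leaves outArmnnActivation unchanged; callers only apply the activation when m_ActivationFunc > 0.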
+        outA = 0;
+        outB = 0;
+        return;
+
+    case 1: // Relu
+        outArmnnActivation = armnn::ActivationFunction::ReLu;
+        outA = 0;
+        outB = 0;
+        return;
+
+    case 3: // Relu6
+        outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
+        outA = 6;
+        outB = 0;
+        return;
+
+    case 4: // Tanh
+        outArmnnActivation = armnn::ActivationFunction::TanH;
+        outA = 1;
+        outB = 1;
+        return;
+
+    case 6: // Sigmoid
+        outArmnnActivation = armnn::ActivationFunction::Sigmoid;
+        outA = 0;
+        outB = 0;
+        return;
+
+    default:
+        throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
+    }
+}
+
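+// Wraps an optional constant tensor in a scoped handle, or returns nullptr when the tensor is absent.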
+std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
+{
+    if (!ptr)
+    {
+        return nullptr;
+    }
+
+    return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
+}
+
+} // anonymous namespace
diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp b/src/backends/reference/workloads/RefLstmFloat32Workload.cpp
deleted file mode 100644
index c697b66..0000000
--- a/src/backends/reference/workloads/RefLstmFloat32Workload.cpp
+++ /dev/null
@@ -1,379 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefLstmFloat32Workload.hpp"
-#include "RefWorkloadUtils.hpp"
-#include "Activation.hpp"
-
-namespace
-{
-
-// Helper functions ported from the Android code base
-// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
-
-void MatrixBatchVectorMultiplyAccumulate(const float* matrix,
-                                         uint32_t mRows,
-                                         uint32_t mCols,
-                                         const float* vector,
-                                         uint32_t nBatch,
-                                         float* outResult,
-                                         int resultStride = 1)
-{
-    float* resultInBatch = outResult;
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        const float* matrixPtr = matrix;
-        for (uint32_t r = 0; r < mRows; r++)
-        {
-            const float* vectorInBatch = vector + b * mCols;
-            for (uint32_t c = 0; c < mCols; c++)
-            {
-                *resultInBatch += *matrixPtr++ * *vectorInBatch++;
-            }
-            resultInBatch += resultStride;
-        }
-    }
-}
-
-void VectorBatchVectorAssign(const float* vector,
-                             uint32_t vSize,
-                             uint32_t nBatch,
-                             float* outBatchVector)
-{
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        memcpy(outBatchVector + b * vSize, vector, vSize * sizeof(float));
-    }
-}
-
-void VectorBatchVectorCwiseProductAccumulate(const float* vector,
-                                             uint32_t vSize,
-                                             const float* batchVector,
-                                             uint32_t nBatch,
-                                             float* outResult)
-{
-    for (uint32_t b = 0; b < nBatch; b++)
-    {
-        for (uint32_t v = 0; v < vSize; v++)
-        {
-            *outResult++ += vector[v] * *batchVector++;
-        }
-    }
-}
-
-void Sub1Vector(const float* vector,
-                uint32_t vSize,
-                float* result)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *result++ = 1.0f - *vector++;
-    }
-}
-
-void VectorVectorCwiseProduct(const float* vector1,
-                              const float* vector2,
-                              uint32_t vSize,
-                              float* outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *outResult++ = *vector1++ * *vector2++;
-    }
-}
-
-void VectorVectorCwiseProductAccumulate(const float* vector1,
-                                        const float* vector2,
-                                        uint32_t vSize,
-                                        float* outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *outResult++ += *vector1++ * *vector2++;
-    }
-}
-
-float Clip(float f,
-           float absLimit)
-{
-    float result = (absLimit < f) ? absLimit : f;
-    result = (-absLimit > result) ? -absLimit : result;
-    return result;
-}
-
-void ClipVector(const float* vector,
-                uint32_t vSize,
-                float absLimit,
-                float* outResult)
-{
-    for (uint32_t v = 0; v < vSize; v++)
-    {
-        *outResult++ = Clip(*vector++, absLimit);
-    }
-}
-
-void CopyVector(const float* vector,
-                uint32_t vSize,
-                float* outResult)
-{
-    memcpy(outResult, vector, vSize * sizeof(float));
-}
-
-void SetActivationParameters(uint32_t activation,
-                             armnn::ActivationFunction& outArmnnActivation,
-                             float& outA,
-                             float& outB)
-{
-    switch (activation)
-    {
-    case 0: // None
-        outA = 0;
-        outB = 0;
-        return;
-
-    case 1: // Relu
-        outArmnnActivation = armnn::ActivationFunction::ReLu;
-        outA = 0;
-        outB = 0;
-        return;
-
-    case 3: // Relu6
-        outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
-        outA = 6;
-        outB = 0;
-        return;
-
-    case 4: // Tanh
-        outArmnnActivation = armnn::ActivationFunction::TanH;
-        outA = 1;
-        outB = 1;
-        return;
-
-    case 6: // Sigmoid
-        outArmnnActivation = armnn::ActivationFunction::Sigmoid;
-        outA = 0;
-        outB = 0;
-        return;
-
-    default:
-        throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
-    }
-}
-
-std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
-{
-    if (!ptr)
-    {
-        return nullptr;
-    }
-
-    return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
-}
-
-} // anonymous namespace
-
-namespace armnn
-{
-
-RefLstmFloat32Workload::RefLstmFloat32Workload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
-    : Float32Workload<LstmQueueDescriptor>(descriptor, info)
-    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
-    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
-    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
-    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
-    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
-    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
-    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
-    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
-    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
-    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
-    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
-    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
-    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
-    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
-    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
-    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
-    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
-{}
-
-void RefLstmFloat32Workload::Execute() const
-{
-    // This is a porting of the LSTM::Eval() method in the Android code base
-    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
-
-    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
-    const TensorShape& inputShape = inputInfo.GetShape();
-
-    float* scratchBuffer  = GetOutputTensorDataFloat(0, m_Data);
-    float* outputStateOut = GetOutputTensorDataFloat(1, m_Data);
-    float* cellStateOut   = GetOutputTensorDataFloat(2, m_Data);
-    float* output         = GetOutputTensorDataFloat(3, m_Data);
-
-    const float* inputData     = GetInputTensorDataFloat(0, m_Data);
-    const float* outputStateIn = GetInputTensorDataFloat(1, m_Data);
-    const float* cellStateIn   = GetInputTensorDataFloat(2, m_Data);
-
-    const uint32_t nBatch = inputShape[0];
-    const uint32_t nInput = inputShape[1];
-
-    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
-    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
-
-    const bool useCifg     = m_Data.m_Parameters.m_CifgEnabled;
-    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
-
-    // Index the scratch buffers pointers to the global scratch buffer.
-    float* inputGateScratch  = nullptr;
-    float* cellScratch       = nullptr;
-    float* forgetGateScratch = nullptr;
-    float* outputGateScratch = nullptr;
-
-    if (useCifg)
-    {
-        cellScratch       = scratchBuffer + 0 * nCell * nBatch;
-        forgetGateScratch = scratchBuffer + 1 * nCell * nBatch;
-        outputGateScratch = scratchBuffer + 2 * nCell * nBatch;
-    }
-    else
-    {
-        inputGateScratch  = scratchBuffer + 0 * nCell * nBatch;
-        cellScratch       = scratchBuffer + 1 * nCell * nBatch;
-        forgetGateScratch = scratchBuffer + 2 * nCell * nBatch;
-        outputGateScratch = scratchBuffer + 3 * nCell * nBatch;
-    }
-
-    // Initialize scratch buffers with bias.
-    if (!useCifg)
-    {
-        VectorBatchVectorAssign(m_InputGateBiasTensor->GetTensor<float>(),
-                                nCell, nBatch, inputGateScratch);
-    }
-    VectorBatchVectorAssign(m_ForgetGateBiasTensor->GetTensor<float>(),
-                            nCell, nBatch, forgetGateScratch);
-    VectorBatchVectorAssign(m_CellBiasTensor->GetTensor<float>(),
-                            nCell, nBatch, cellScratch);
-    VectorBatchVectorAssign(m_OutputGateBiasTensor->GetTensor<float>(),
-                            nCell, nBatch, outputGateScratch);
-
-    // For each batch and cell: compute input_weight * input.
-    if (!useCifg)
-    {
-        MatrixBatchVectorMultiplyAccumulate(m_InputToInputWeightsTensor->GetTensor<float>(),
-                                            nCell, nInput, inputData, nBatch, inputGateScratch);
-    }
-    MatrixBatchVectorMultiplyAccumulate(m_InputToForgetWeightsTensor->GetTensor<float>(),
-                                        nCell, nInput, inputData, nBatch, forgetGateScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_InputToCellWeightsTensor->GetTensor<float>(),
-                                        nCell, nInput, inputData, nBatch, cellScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_InputToOutputWeightsTensor->GetTensor<float>(),
-                                        nCell, nInput, inputData, nBatch, outputGateScratch);
-
-    // For each batch and cell: compute recurrent_weight * output_state.
-    if (!useCifg)
-    {
-        MatrixBatchVectorMultiplyAccumulate(m_RecurrentToInputWeightsTensor->GetTensor<float>(),
-                                            nCell, nOutput, outputStateIn, nBatch, inputGateScratch);
-    }
-    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToForgetWeightsTensor->GetTensor<float>(),
-                                        nCell, nOutput, outputStateIn, nBatch, forgetGateScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToCellWeightsTensor->GetTensor<float>(),
-                                        nCell, nOutput, outputStateIn, nBatch, cellScratch);
-    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToOutputWeightsTensor->GetTensor<float>(),
-                                        nCell, nOutput, outputStateIn, nBatch, outputGateScratch);
-
-    // For each batch and cell: update input gate.
-    if (!useCifg)
-    {
-        if (usePeephole)
-        {
-            VectorBatchVectorCwiseProductAccumulate(m_CellToInputWeightsTensor->GetTensor<float>(),
-                                                    nCell, cellStateIn, nBatch, inputGateScratch);
-        }
-        Activation(inputGateScratch, inputGateScratch,
-                   TensorInfo({nCell, nBatch}, DataType::Float32),
-                   ActivationFunction::Sigmoid, 0, 0);
-    }
-
-    // For each batch and cell: update forget gate.
-    if (usePeephole)
-    {
-        VectorBatchVectorCwiseProductAccumulate(m_CellToForgetWeightsTensor->GetTensor<float>(), nCell,
-                                                cellStateIn, nBatch, forgetGateScratch);
-    }
-    Activation(forgetGateScratch, forgetGateScratch,
-               TensorInfo({nCell, nBatch}, DataType::Float32),
-               ActivationFunction::Sigmoid, 0, 0);
-
-    // For each batch and cell: update the cell.
-    VectorVectorCwiseProduct(forgetGateScratch, cellStateIn, nBatch * nCell, cellStateOut);
-
-    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
-    float a = 0;
-    float b = 0;
-    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);
-
-    if (m_Data.m_Parameters.m_ActivationFunc > 0)
-    {
-        Activation(cellScratch, cellScratch,
-                   TensorInfo({nCell, nBatch}, DataType::Float32),
-                   armnnActivationFunc, a, b);
-    }
-    if (useCifg)
-    {
-        Sub1Vector(forgetGateScratch, nBatch * nCell, forgetGateScratch);
-        VectorVectorCwiseProductAccumulate(cellScratch, forgetGateScratch, nBatch * nCell, cellStateOut);
-    }
-    else
-    {
-        VectorVectorCwiseProductAccumulate(cellScratch, inputGateScratch, nBatch * nCell, cellStateOut);
-    }
-    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
-    {
-        ClipVector(cellStateOut, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, cellStateOut);
-    }
-
-    // For each batch and cell: update the output gate.
-    if (usePeephole)
-    {
-        VectorBatchVectorCwiseProductAccumulate(m_CellToOutputWeightsTensor->GetTensor<float>(),
-                                                nCell, cellStateOut, nBatch, outputGateScratch);
-    }
-    Activation(outputGateScratch, outputGateScratch,
-               TensorInfo({nCell, nBatch}, DataType::Float32),
-               ActivationFunction::Sigmoid, 0, 0);
-
-    if (m_Data.m_Parameters.m_ActivationFunc > 0)
-    {
-        Activation(cellStateOut, cellScratch,
-                   TensorInfo({nCell, nBatch}, DataType::Float32),
-                   armnnActivationFunc, a, b);
-    }
-    VectorVectorCwiseProduct(outputGateScratch, cellScratch, nBatch * nCell, outputGateScratch);
-
-    // For each batch: update the projection and output_state.
-    if (m_Data.m_Parameters.m_ProjectionEnabled)
-    {
-        if (m_ProjectionBiasTensor)
-        {
-            VectorBatchVectorAssign(m_ProjectionBiasTensor->GetTensor<float>(),
-                                    nOutput, nBatch, output);
-        }
-        MatrixBatchVectorMultiplyAccumulate(m_ProjectionWeightsTensor->GetTensor<float>(),
-                                            nOutput, nCell, outputGateScratch, nBatch, output);
-
-        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
-        {
-            ClipVector(output, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, output);
-        }
-    }
-    else
-    {
-        CopyVector(outputGateScratch, nBatch * nOutput, output);
-    }
-
-    CopyVector(output, nBatch * nOutput, outputStateOut);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
new file mode 100644
index 0000000..f8ebc58
--- /dev/null
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -0,0 +1,307 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefLstmWorkload.hpp"
+#include "Activation.hpp"
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+#include "LstmUtils.hpp"
+#include "RefWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
+    : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
+    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
+    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
+    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
+    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
+    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
+    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
+    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
+    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
+    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
+    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
+    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
+    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
+    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
+    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
+    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
+    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
+    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
+{}
+
+void RefLstmWorkload::Execute() const
+{
+    // This is a port of the LSTM::Eval() method in the Android code base
+    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
+
+    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    const TensorShape& inputShape = inputInfo.GetShape();
+    const DataType outputType = outputInfo.GetDataType();
+
+    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map());
+    std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+    std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
+    std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());
+
+    std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
+    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map());
+    std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map());
+
+    const uint32_t nBatch = inputShape[0];
+    const uint32_t nInput = inputShape[1];
+
+    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
+    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
+
+    const bool useCifg     = m_Data.m_Parameters.m_CifgEnabled;
+    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
+
+    // Index the scratch buffers pointers to the global scratch buffer.
+    std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
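+    // Decoder views over the same scratch buffer, for steps that read gate values back.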
+    std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Decoder<float>> cellScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+    std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
+        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
+
+    if (useCifg)
+    {
+        *cellScratch       += (0 * nCell * nBatch);
+        *forgetGateScratch += (1 * nCell * nBatch);
+        *outputGateScratch += (2 * nCell * nBatch);
+
+        *cellScratchDecoder       += (0 * nCell * nBatch);
+        *forgetGateScratchDecoder += (1 * nCell * nBatch);
+        *outputGateScratchDecoder += (2 * nCell * nBatch);
+    }
+    else
+    {
+        *inputGateScratch  += (0 * nCell * nBatch);
+        *cellScratch       += (1 * nCell * nBatch);
+        *forgetGateScratch += (2 * nCell * nBatch);
+        *outputGateScratch += (3 * nCell * nBatch);
+
+        *inputGateScratchDecoder  += (0 * nCell * nBatch);
+        *cellScratchDecoder       += (1 * nCell * nBatch);
+        *forgetGateScratchDecoder += (2 * nCell * nBatch);
+        *outputGateScratchDecoder += (3 * nCell * nBatch);
+    }
+
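+    // Optional weights and biases are only decoded when the corresponding feature is enabled below.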
+    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
+    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
+        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
+        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
+        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
+
+    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
+    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
+        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
+        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
+        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
+
+    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
+    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
+        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
+        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>());
+    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
+        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>());
+
+    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
+    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
+    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
+
+    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
+    std::unique_ptr<Decoder<float>> projectionBiasTensor;
+
+    if (!useCifg)
+    {
+        inputToInputWeightsTensor = MakeDecoder<float>(
+            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
+        inputGateBiasTensor = MakeDecoder<float>(
+            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>());
+        recurrentToInputWeightsTensor = MakeDecoder<float>(
+            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
+    }
+
+    if (usePeephole)
+    {
+        cellToForgetWeightsTensor = MakeDecoder<float>(
+            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
+        cellToOutputWeightsTensor = MakeDecoder<float>(
+            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
+    }
+
+    if (!useCifg && usePeephole)
+    {
+        cellToInputWeightsTensor = MakeDecoder<float>(
+            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
+    }
+
+    if (m_Data.m_Parameters.m_ProjectionEnabled)
+    {
+        projectionWeightsTensor = MakeDecoder<float>(
+            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
+        if (m_ProjectionBiasTensor)
+        {
+            projectionBiasTensor = MakeDecoder<float>(
+                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
+        }
+    }
+
+    // Initialize scratch buffers with bias.
+    if (!useCifg)
+    {
+        VectorBatchVectorAssign(*inputGateBiasTensor,
+                                nCell, nBatch, *inputGateScratch);
+    }
+    VectorBatchVectorAssign(*forgetGateBiasTensor,
+                            nCell, nBatch, *forgetGateScratch);
+    VectorBatchVectorAssign(*cellBiasTensor,
+                            nCell, nBatch, *cellScratch);
+    VectorBatchVectorAssign(*outputGateBiasTensor,
+                            nCell, nBatch, *outputGateScratch);
+
+    // For each batch and cell: compute input_weight * input.
+    if (!useCifg)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
+                                            nCell, nInput, *inputData, nBatch, *inputGateScratch);
+    }
+    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *forgetGateScratch);
+    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *cellScratch);
+    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
+                                        nCell, nInput, *inputData, nBatch, *outputGateScratch);
+
+    // For each batch and cell: compute recurrent_weight * output_state.
+    if (!useCifg)
+    {
+        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
+                                            nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
+    }
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
+    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
+                                        nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
+
+    // For each batch and cell: update input gate.
+    if (!useCifg)
+    {
+        if (usePeephole)
+        {
+            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
+                                                    nCell, *cellStateIn, nBatch, *inputGateScratch);
+        }
+        Activation(*inputGateScratchDecoder, *inputGateScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   ActivationFunction::Sigmoid, 0, 0);
+    }
+
+    // For each batch and cell: update forget gate.
+    if (usePeephole)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
+                                                *cellStateIn, nBatch, *forgetGateScratch);
+    }
+    Activation(*forgetGateScratchDecoder, *forgetGateScratch,
+               TensorInfo({nCell, nBatch}, outputType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    // For each batch and cell: update the cell.
+    VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
+
+    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
+    float a = 0;
+    float b = 0;
+    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);
+
+    if (m_Data.m_Parameters.m_ActivationFunc > 0)
+    {
+        Activation(*cellScratchDecoder, *cellScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   armnnActivationFunc, a, b);
+    }
+    if (useCifg)
+    {
+        Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
+        VectorVectorCwiseProductAccumulate(
+            *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
+    }
+    else
+    {
+        VectorVectorCwiseProductAccumulate(
+            *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
+    }
+    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
+    {
+        ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
+    }
+
+    // For each batch and cell: update the output gate.
+    if (usePeephole)
+    {
+        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
+                                                nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
+    }
+    Activation(*outputGateScratchDecoder, *outputGateScratch,
+               TensorInfo({nCell, nBatch}, outputType),
+               ActivationFunction::Sigmoid, 0, 0);
+
+    if (m_Data.m_Parameters.m_ActivationFunc > 0)
+    {
+        Activation(*cellStateOutDecoder, *cellScratch,
+                   TensorInfo({nCell, nBatch}, outputType),
+                   armnnActivationFunc, a, b);
+    }
+
+    VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
+
+    // For each batch: update the projection and output_state.
+    if (m_Data.m_Parameters.m_ProjectionEnabled)
+    {
+        if (m_ProjectionBiasTensor)
+        {
+            VectorBatchVectorAssign(*projectionBiasTensor,
+                                    nOutput, nBatch, *output);
+        }
+        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
+                                            nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
+
+        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
+        {
+            ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
+        }
+    }
+    else
+    {
+        CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
+    }
+
+    CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
+}
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp b/src/backends/reference/workloads/RefLstmWorkload.hpp
similarity index 89%
rename from src/backends/reference/workloads/RefLstmFloat32Workload.hpp
rename to src/backends/reference/workloads/RefLstmWorkload.hpp
index a2dead8..38e3fb9 100644
--- a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.hpp
@@ -13,10 +13,10 @@
 namespace armnn
 {
 
-class RefLstmFloat32Workload : public Float32Workload<LstmQueueDescriptor>
+class RefLstmWorkload : public BaseWorkload<LstmQueueDescriptor>
 {
 public:
-    explicit RefLstmFloat32Workload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
+    explicit RefLstmWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
 
     virtual void Execute() const override;
 
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 7871a1b..8ffd348 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -51,7 +51,7 @@
 #include "Pooling2d.hpp"
 #include "RefFakeQuantizationFloat32Workload.hpp"
 #include "RefPermuteWorkload.hpp"
-#include "RefLstmFloat32Workload.hpp"
+#include "RefLstmWorkload.hpp"
 #include "RefConvertFp16ToFp32Workload.hpp"
 #include "RefConvertFp32ToFp16Workload.hpp"
 #include "RefMeanUint8Workload.hpp"