IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported

!android-nn-driver:1461

Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 58722fe..53dd29d 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -7,6 +7,7 @@
 #include <armnn/Deprecated.hpp>
 #include <armnn/DescriptorsFwd.hpp>
 #include <armnn/Optional.hpp>
+#include <armnn/LstmParams.hpp>
 
 #include <cctype>
 #include <functional>
@@ -153,28 +154,8 @@
                                  const TensorInfo& cellStateOut,
                                  const TensorInfo& output,
                                  const LstmDescriptor& descriptor,
-                                 const TensorInfo& inputToForgetWeights,
-                                 const TensorInfo& inputToCellWeights,
-                                 const TensorInfo& inputToOutputWeights,
-                                 const TensorInfo& recurrentToForgetWeights,
-                                 const TensorInfo& recurrentToCellWeights,
-                                 const TensorInfo& recurrentToOutputWeights,
-                                 const TensorInfo& forgetGateBias,
-                                 const TensorInfo& cellBias,
-                                 const TensorInfo& outputGateBias,
-                                 const TensorInfo* inputToInputWeights,
-                                 const TensorInfo* recurrentToInputWeights,
-                                 const TensorInfo* cellToInputWeights,
-                                 const TensorInfo* inputGateBias,
-                                 const TensorInfo* projectionWeights,
-                                 const TensorInfo* projectionBias,
-                                 const TensorInfo* cellToForgetWeights,
-                                 const TensorInfo* cellToOutputWeights,
-                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
-                                 const TensorInfo* inputLayerNormWeights = nullptr,
-                                 const TensorInfo* forgetLayerNormWeights = nullptr,
-                                 const TensorInfo* cellLayerNormWeights = nullptr,
-                                 const TensorInfo* outputLayerNormWeights = nullptr) const = 0;
+                                 const LstmInputParamsInfo& paramsInfo,
+                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
 
     virtual bool IsMaximumSupported(const TensorInfo& input0,
                                     const TensorInfo& input1,
diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp
index 35336ed..65f9d08 100644
--- a/include/armnn/LayerSupport.hpp
+++ b/include/armnn/LayerSupport.hpp
@@ -9,6 +9,7 @@
 #include <armnn/Optional.hpp>
 #include <armnn/Tensor.hpp>
 #include <armnn/Types.hpp>
+#include "LstmParams.hpp"
 
 namespace armnn
 {
@@ -178,15 +179,7 @@
                      const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                      const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                      const TensorInfo& output, const LstmDescriptor& descriptor,
-                     const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
-                     const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
-                     const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
-                     const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
-                     const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
-                     const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
-                     const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
-                     const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
-                     const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported = nullptr,
+                     const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported = nullptr,
                      size_t reasonIfUnsupportedMaxLength = 1024);
 
 /// Deprecated in favor of IBackend and ILayerSupport interfaces
diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp
index a7c57c7..0c8e66d 100644
--- a/include/armnn/LstmParams.hpp
+++ b/include/armnn/LstmParams.hpp
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "TensorFwd.hpp"
+#include "Exceptions.hpp"
 
 namespace armnn
 {
@@ -59,5 +60,149 @@
     const ConstTensor* m_OutputLayerNormWeights;
 };
 
+struct LstmInputParamsInfo
+{
+    LstmInputParamsInfo()
+            : m_InputToInputWeights(nullptr)
+            , m_InputToForgetWeights(nullptr)
+            , m_InputToCellWeights(nullptr)
+            , m_InputToOutputWeights(nullptr)
+            , m_RecurrentToInputWeights(nullptr)
+            , m_RecurrentToForgetWeights(nullptr)
+            , m_RecurrentToCellWeights(nullptr)
+            , m_RecurrentToOutputWeights(nullptr)
+            , m_CellToInputWeights(nullptr)
+            , m_CellToForgetWeights(nullptr)
+            , m_CellToOutputWeights(nullptr)
+            , m_InputGateBias(nullptr)
+            , m_ForgetGateBias(nullptr)
+            , m_CellBias(nullptr)
+            , m_OutputGateBias(nullptr)
+            , m_ProjectionWeights(nullptr)
+            , m_ProjectionBias(nullptr)
+            , m_InputLayerNormWeights(nullptr)
+            , m_ForgetLayerNormWeights(nullptr)
+            , m_CellLayerNormWeights(nullptr)
+            , m_OutputLayerNormWeights(nullptr)
+    {
+    }
+    const TensorInfo* m_InputToInputWeights;
+    const TensorInfo* m_InputToForgetWeights;
+    const TensorInfo* m_InputToCellWeights;
+    const TensorInfo* m_InputToOutputWeights;
+    const TensorInfo* m_RecurrentToInputWeights;
+    const TensorInfo* m_RecurrentToForgetWeights;
+    const TensorInfo* m_RecurrentToCellWeights;
+    const TensorInfo* m_RecurrentToOutputWeights;
+    const TensorInfo* m_CellToInputWeights;
+    const TensorInfo* m_CellToForgetWeights;
+    const TensorInfo* m_CellToOutputWeights;
+    const TensorInfo* m_InputGateBias;
+    const TensorInfo* m_ForgetGateBias;
+    const TensorInfo* m_CellBias;
+    const TensorInfo* m_OutputGateBias;
+    const TensorInfo* m_ProjectionWeights;
+    const TensorInfo* m_ProjectionBias;
+    const TensorInfo* m_InputLayerNormWeights;
+    const TensorInfo* m_ForgetLayerNormWeights;
+    const TensorInfo* m_CellLayerNormWeights;
+    const TensorInfo* m_OutputLayerNormWeights;
+
+    const TensorInfo& deref(const TensorInfo* tensorInfo) const
+    {
+        if (tensorInfo != nullptr)
+        {
+            const TensorInfo &temp = *tensorInfo;
+            return temp;
+        }
+        throw InvalidArgumentException("Can't dereference a null pointer");
+    }
+
+    const TensorInfo& get_InputToInputWeights() const
+    {
+        return deref(m_InputToInputWeights);
+    }
+    const TensorInfo& get_InputToForgetWeights() const
+    {
+        return deref(m_InputToForgetWeights);
+    }
+    const TensorInfo& get_InputToCellWeights() const
+    {
+        return deref(m_InputToCellWeights);
+    }
+    const TensorInfo& get_InputToOutputWeights() const
+    {
+        return deref(m_InputToOutputWeights);
+    }
+    const TensorInfo& get_RecurrentToInputWeights() const
+    {
+        return deref(m_RecurrentToInputWeights);
+    }
+    const TensorInfo& get_RecurrentToForgetWeights() const
+    {
+        return deref(m_RecurrentToForgetWeights);
+    }
+    const TensorInfo& get_RecurrentToCellWeights() const
+    {
+        return deref(m_RecurrentToCellWeights);
+    }
+    const TensorInfo& get_RecurrentToOutputWeights() const
+    {
+        return deref(m_RecurrentToOutputWeights);
+    }
+    const TensorInfo& get_CellToInputWeights() const
+    {
+        return deref(m_CellToInputWeights);
+    }
+    const TensorInfo& get_CellToForgetWeights() const
+    {
+        return deref(m_CellToForgetWeights);
+    }
+    const TensorInfo& get_CellToOutputWeights() const
+    {
+        return deref(m_CellToOutputWeights);
+    }
+    const TensorInfo& get_InputGateBias() const
+    {
+        return deref(m_InputGateBias);
+    }
+    const TensorInfo& get_ForgetGateBias() const
+    {
+        return deref(m_ForgetGateBias);
+    }
+    const TensorInfo& get_CellBias() const
+    {
+        return deref(m_CellBias);
+    }
+    const TensorInfo& get_OutputGateBias() const
+    {
+        return deref(m_OutputGateBias);
+    }
+    const TensorInfo& get_ProjectionWeights() const
+    {
+        return deref(m_ProjectionWeights);
+    }
+    const TensorInfo& get_ProjectionBias() const
+    {
+        return deref(m_ProjectionBias);
+    }
+    const TensorInfo& get_InputLayerNormWeights() const
+    {
+        return deref(m_InputLayerNormWeights);
+    }
+    const TensorInfo& get_ForgetLayerNormWeights() const
+    {
+        return deref(m_ForgetLayerNormWeights);
+    }
+    const TensorInfo& get_CellLayerNormWeights() const
+    {
+        return deref(m_CellLayerNormWeights);
+    }
+    const TensorInfo& get_OutputLayerNormWeights() const
+    {
+        return deref(m_OutputLayerNormWeights);
+    }
+};
+
 } // namespace armnn
 
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index b2ca85c..a2908aa 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -333,27 +333,13 @@
                      const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                      const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                      const TensorInfo& output, const LstmDescriptor& descriptor,
-                     const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
-                     const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
-                     const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
-                     const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
-                     const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
-                     const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
-                     const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
-                     const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
-                     const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported,
+                     const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported,
                      size_t reasonIfUnsupportedMaxLength)
 
 {
     FORWARD_LAYER_SUPPORT_FUNC(backend, IsLstmSupported, input, outputStateIn, cellStateIn,
                                scratchBuffer, outputStateOut, cellStateOut,
-                               output, descriptor, inputToForgetWeights, inputToCellWeights,
-                               inputToOutputWeights, recurrentToForgetWeights,
-                               recurrentToCellWeights, recurrentToOutputWeights,
-                               forgetGateBias, cellBias, outputGateBias,
-                               inputToInputWeights, recurrentToInputWeights,
-                               cellToInputWeights, inputGateBias, projectionWeights,
-                               projectionBias, cellToForgetWeights, cellToOutputWeights);
+                               output, descriptor, paramsInfo);
 }
 
 bool IsMaximumSupported(const BackendId& backend,
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 4488e25..ea22fac 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -226,28 +226,8 @@
                                        const TensorInfo& cellStateOut,
                                        const TensorInfo& output,
                                        const LstmDescriptor& descriptor,
-                                       const TensorInfo& inputToForgetWeights,
-                                       const TensorInfo& inputToCellWeights,
-                                       const TensorInfo& inputToOutputWeights,
-                                       const TensorInfo& recurrentToForgetWeights,
-                                       const TensorInfo& recurrentToCellWeights,
-                                       const TensorInfo& recurrentToOutputWeights,
-                                       const TensorInfo& forgetGateBias,
-                                       const TensorInfo& cellBias,
-                                       const TensorInfo& outputGateBias,
-                                       const TensorInfo* inputToInputWeights,
-                                       const TensorInfo* recurrentToInputWeights,
-                                       const TensorInfo* cellToInputWeights,
-                                       const TensorInfo* inputGateBias,
-                                       const TensorInfo* projectionWeights,
-                                       const TensorInfo* projectionBias,
-                                       const TensorInfo* cellToForgetWeights,
-                                       const TensorInfo* cellToOutputWeights,
-                                       Optional<std::string&> reasonIfUnsupported,
-                                       const TensorInfo* inputLayerNormWeights,
-                                       const TensorInfo* forgetLayerNormWeights,
-                                       const TensorInfo* cellLayerNormWeights,
-                                       const TensorInfo* outputLayerNormWeights) const
+                                       const LstmInputParamsInfo& paramsInfo,
+                                       Optional<std::string&> reasonIfUnsupported) const
 {
     return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
 }
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 03a928a..36b8e77 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -140,28 +140,8 @@
                          const TensorInfo& cellStateOut,
                          const TensorInfo& output,
                          const LstmDescriptor& descriptor,
-                         const TensorInfo& inputToForgetWeights,
-                         const TensorInfo& inputToCellWeights,
-                         const TensorInfo& inputToOutputWeights,
-                         const TensorInfo& recurrentToForgetWeights,
-                         const TensorInfo& recurrentToCellWeights,
-                         const TensorInfo& recurrentToOutputWeights,
-                         const TensorInfo& forgetGateBias,
-                         const TensorInfo& cellBias,
-                         const TensorInfo& outputGateBias,
-                         const TensorInfo* inputToInputWeights,
-                         const TensorInfo* recurrentToInputWeights,
-                         const TensorInfo* cellToInputWeights,
-                         const TensorInfo* inputGateBias,
-                         const TensorInfo* projectionWeights,
-                         const TensorInfo* projectionBias,
-                         const TensorInfo* cellToForgetWeights,
-                         const TensorInfo* cellToOutputWeights,
-                         Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
-                         const TensorInfo* inputLayerNormWeights = nullptr,
-                         const TensorInfo* forgetLayerNormWeights = nullptr,
-                         const TensorInfo* cellLayerNormWeights = nullptr,
-                         const TensorInfo* outputLayerNormWeights = nullptr) const override;
+                         const LstmInputParamsInfo& paramsInfo,
+                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
     bool IsMaximumSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 3502c38..1c23e17 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -388,20 +388,20 @@
             const TensorInfo& outputGateBias
                     = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
 
-            // Optional parameters
-            const TensorInfo* inputToInputWeights = nullptr;
-            const TensorInfo* recurrentToInputWeights = nullptr;
-            const TensorInfo* cellToInputWeights = nullptr;
-            const TensorInfo* inputGateBias = nullptr;
-            const TensorInfo* projectionWeights = nullptr;
-            const TensorInfo* projectionBias = nullptr;
-            const TensorInfo* cellToForgetWeights = nullptr;
-            const TensorInfo* cellToOutputWeights = nullptr;
-            const TensorInfo* inputLayerNormWeights = nullptr;
-            const TensorInfo* forgetLayerNormWeights = nullptr;
-            const TensorInfo* cellLayerNormWeights = nullptr;
-            const TensorInfo* outputLayerNormWeights = nullptr;
+            LstmInputParamsInfo paramsInfo;
 
+            paramsInfo.m_InputToForgetWeights     = &inputToForgetWeights;
+            paramsInfo.m_InputToCellWeights       = &inputToCellWeights;
+            paramsInfo.m_InputToOutputWeights     = &inputToOutputWeights;
+            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+            paramsInfo.m_RecurrentToCellWeights   = &recurrentToCellWeights;
+            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+            paramsInfo.m_ForgetGateBias           = &forgetGateBias;
+            paramsInfo.m_CellBias                 = &cellBias;
+            paramsInfo.m_OutputGateBias           = &outputGateBias;
+
+
+            // Optional parameters
             TensorInfo optInputToInputWeights;
             TensorInfo optRecurrentToInputWeights;
             TensorInfo optCellToInputWeights;
@@ -419,32 +419,32 @@
             {
                 optInputToInputWeights =
                     OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
-                inputToInputWeights = &optInputToInputWeights;
+                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
 
                 optRecurrentToInputWeights =
                     OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
-                recurrentToInputWeights = &optRecurrentToInputWeights;
+                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
                 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
                 {
                     optCellToInputWeights =
                         OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
-                    cellToInputWeights = &optCellToInputWeights;
+                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
                 }
                 optInputGateBias =
                        OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
-                inputGateBias = &optInputGateBias;
+                paramsInfo.m_InputGateBias = &optInputGateBias;
             }
 
             if(descriptor.m_ProjectionEnabled)
             {
                 optProjectionWeights =
                     OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
-                projectionWeights = &optProjectionWeights;
+                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
                 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
                 {
                     optProjectionBias =
                         OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
-                    projectionBias = &optProjectionBias;
+                    paramsInfo.m_ProjectionBias = &optProjectionBias;
                 }
             }
 
@@ -452,29 +452,29 @@
             {
                 optCellToForgetWeights =
                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
-                cellToForgetWeights = &optCellToForgetWeights;
+                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
                 optCellToOutputWeights =
                     OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
-                cellToOutputWeights = &optCellToOutputWeights;
+                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
             }
 
             if(descriptor.m_LayerNormEnabled)
             {
                 optInputLayerNormWeights = OverrideDataType(
                         cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
-                inputLayerNormWeights = &optInputLayerNormWeights;
+                paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
 
                 optForgetLayerNormWeights = OverrideDataType(
                         cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
-                forgetLayerNormWeights = &optForgetLayerNormWeights;
+                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
 
                 optCellLayerNormWeights = OverrideDataType(
                         cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
-                cellLayerNormWeights = &optCellLayerNormWeights;
+                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
 
                 optOutputLayerNormWeights = OverrideDataType(
                         cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
-                outputLayerNormWeights = &optOutputLayerNormWeights;
+                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
             }
 
             result = layerSupportObject->IsLstmSupported(
@@ -486,28 +486,8 @@
                                      cellStateOut,
                                      output,
                                      descriptor,
-                                     inputToForgetWeights,
-                                     inputToCellWeights,
-                                     inputToOutputWeights,
-                                     recurrentToForgetWeights,
-                                     recurrentToCellWeights,
-                                     recurrentToOutputWeights,
-                                     forgetGateBias,
-                                     cellBias,
-                                     outputGateBias,
-                                     inputToInputWeights,
-                                     recurrentToInputWeights,
-                                     cellToInputWeights,
-                                     inputGateBias,
-                                     projectionWeights,
-                                     projectionBias,
-                                     cellToForgetWeights,
-                                     cellToOutputWeights,
-                                     reason,
-                                     inputLayerNormWeights,
-                                     forgetLayerNormWeights,
-                                     cellLayerNormWeights,
-                                     outputLayerNormWeights);
+                                     paramsInfo,
+                                     reason);
             break;
         }
         case LayerType::Maximum:
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 497a643..6d9b197 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -405,28 +405,8 @@
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
-                                     const TensorInfo& inputToForgetWeights,
-                                     const TensorInfo& inputToCellWeights,
-                                     const TensorInfo& inputToOutputWeights,
-                                     const TensorInfo& recurrentToForgetWeights,
-                                     const TensorInfo& recurrentToCellWeights,
-                                     const TensorInfo& recurrentToOutputWeights,
-                                     const TensorInfo& forgetGateBias,
-                                     const TensorInfo& cellBias,
-                                     const TensorInfo& outputGateBias,
-                                     const TensorInfo* inputToInputWeights,
-                                     const TensorInfo* recurrentToInputWeights,
-                                     const TensorInfo* cellToInputWeights,
-                                     const TensorInfo* inputGateBias,
-                                     const TensorInfo* projectionWeights,
-                                     const TensorInfo* projectionBias,
-                                     const TensorInfo* cellToForgetWeights,
-                                     const TensorInfo* cellToOutputWeights,
-                                     Optional<std::string&> reasonIfUnsupported,
-                                     const TensorInfo* inputLayerNormWeights,
-                                     const TensorInfo* forgetLayerNormWeights,
-                                     const TensorInfo* cellLayerNormWeights,
-                                     const TensorInfo* outputLayerNormWeights) const
+                                     const LstmInputParamsInfo& paramsInfo,
+                                     Optional<std::string&> reasonIfUnsupported) const
 {
     FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
                                    reasonIfUnsupported,
@@ -438,23 +418,7 @@
                                    cellStateOut,
                                    output,
                                    descriptor,
-                                   inputToForgetWeights,
-                                   inputToCellWeights,
-                                   inputToOutputWeights,
-                                   recurrentToForgetWeights,
-                                   recurrentToCellWeights,
-                                   recurrentToOutputWeights,
-                                   forgetGateBias,
-                                   cellBias,
-                                   outputGateBias,
-                                   inputToInputWeights,
-                                   recurrentToInputWeights,
-                                   cellToInputWeights,
-                                   inputGateBias,
-                                   projectionWeights,
-                                   projectionBias,
-                                   cellToForgetWeights,
-                                   cellToOutputWeights);
+                                   paramsInfo);
 }
 
 bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 4a55997..63a4daf 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -114,28 +114,8 @@
                          const TensorInfo& cellStateOut,
                          const TensorInfo& output,
                          const LstmDescriptor& descriptor,
-                         const TensorInfo& inputToForgetWeights,
-                         const TensorInfo& inputToCellWeights,
-                         const TensorInfo& inputToOutputWeights,
-                         const TensorInfo& recurrentToForgetWeights,
-                         const TensorInfo& recurrentToCellWeights,
-                         const TensorInfo& recurrentToOutputWeights,
-                         const TensorInfo& forgetGateBias,
-                         const TensorInfo& cellBias,
-                         const TensorInfo& outputGateBias,
-                         const TensorInfo* inputToInputWeights,
-                         const TensorInfo* recurrentToInputWeights,
-                         const TensorInfo* cellToInputWeights,
-                         const TensorInfo* inputGateBias,
-                         const TensorInfo* projectionWeights,
-                         const TensorInfo* projectionBias,
-                         const TensorInfo* cellToForgetWeights,
-                         const TensorInfo* cellToOutputWeights,
-                         Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
-                         const TensorInfo* inputLayerNormWeights = nullptr,
-                         const TensorInfo* forgetLayerNormWeights = nullptr,
-                         const TensorInfo* cellLayerNormWeights = nullptr,
-                         const TensorInfo* outputLayerNormWeights = nullptr) const override;
+                         const LstmInputParamsInfo& paramsInfo,
+                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
     bool IsMaximumSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
index f4d8974..3dbbbc3 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
@@ -224,22 +224,7 @@
                                                 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                                                 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                                                 const TensorInfo& output, const LstmDescriptor& descriptor,
-                                                const TensorInfo& inputToForgetWeights,
-                                                const TensorInfo& inputToCellWeights,
-                                                const TensorInfo& inputToOutputWeights,
-                                                const TensorInfo& recurrentToForgetWeights,
-                                                const TensorInfo& recurrentToCellWeights,
-                                                const TensorInfo& recurrentToOutputWeights,
-                                                const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
-                                                const TensorInfo& outputGateBias,
-                                                const TensorInfo* inputToInputWeights,
-                                                const TensorInfo* recurrentToInputWeights,
-                                                const TensorInfo* cellToInputWeights,
-                                                const TensorInfo* inputGateBias,
-                                                const TensorInfo* projectionWeights,
-                                                const TensorInfo* projectionBias,
-                                                const TensorInfo* cellToForgetWeights,
-                                                const TensorInfo* cellToOutputWeights)
+                                                const LstmInputParamsInfo& paramsInfo)
 {
     arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
 
@@ -253,18 +238,21 @@
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
 
     // Basic parameters
-    const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
-    const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
-    const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+    const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+    const arm_compute::TensorInfo aclInputToCellWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+    const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
     const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
     const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToCellWeights);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
     const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
-                                  = BuildArmComputeTensorInfo(recurrentToOutputWeights);
-    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
-    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
-    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+                                  = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+    const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+    const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+    const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
 
     arm_compute::TensorInfo aclInputToInputWeightsInfo;
     arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -277,43 +265,37 @@
 
     if (!descriptor.m_CifgEnabled)
     {
-        armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
-        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
-        armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
-        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
+        aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+        aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
 
-        if (cellToInputWeights != nullptr)
+        if (paramsInfo.m_CellToInputWeights != nullptr)
         {
-            armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
-            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+            aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
         }
-        armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
-        aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+        aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
         lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
-                                         cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+                                         paramsInfo.m_CellToInputWeights != nullptr ?
+                                         &aclCellToInputWeightsInfo: nullptr,
                                          &aclInputGateBiasInfo);
     }
 
     if (descriptor.m_ProjectionEnabled)
     {
-        const armnn::TensorInfo& projectionWInfo = *projectionWeights;
-        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
+        aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
 
-        if (projectionBias != nullptr)
+        if (paramsInfo.m_ProjectionBias != nullptr)
         {
-            const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
-            aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+            aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
         }
         lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
-                                               projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+                                               paramsInfo.m_ProjectionBias != nullptr ?
+                                               &aclProjectionBiasInfo: nullptr);
     }
 
     if (descriptor.m_PeepholeEnabled)
     {
-        const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
-        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
-        const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
-        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+        aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+        aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
         lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
     }
 
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
index 6a0c41f..9a3211a 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
@@ -49,20 +49,5 @@
                                                 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                                                 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                                                 const TensorInfo& output, const LstmDescriptor &descriptor,
-                                                const TensorInfo& inputToForgetWeights,
-                                                const TensorInfo& inputToCellWeights,
-                                                const TensorInfo& inputToOutputWeights,
-                                                const TensorInfo& recurrentToForgetWeights,
-                                                const TensorInfo& recurrentToCellWeights,
-                                                const TensorInfo& recurrentToOutputWeights,
-                                                const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
-                                                const TensorInfo& outputGateBias,
-                                                const TensorInfo* inputToInputWeights,
-                                                const TensorInfo* recurrentToInputWeights,
-                                                const TensorInfo* cellToInputWeights,
-                                                const TensorInfo* inputGateBias,
-                                                const TensorInfo* projectionWeights,
-                                                const TensorInfo* projectionBias,
-                                                const TensorInfo* cellToForgetWeights,
-                                                const TensorInfo* cellToOutputWeights);
+                                                const LstmInputParamsInfo& paramsInfo);
 } //namespace armnn
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index ac7f310..59c14c4 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -924,51 +924,11 @@
                                       const TensorInfo& cellStateOut,
                                       const TensorInfo& output,
                                       const LstmDescriptor& descriptor,
-                                      const TensorInfo& inputToForgetWeights,
-                                      const TensorInfo& inputToCellWeights,
-                                      const TensorInfo& inputToOutputWeights,
-                                      const TensorInfo& recurrentToForgetWeights,
-                                      const TensorInfo& recurrentToCellWeights,
-                                      const TensorInfo& recurrentToOutputWeights,
-                                      const TensorInfo& forgetGateBias,
-                                      const TensorInfo& cellBias,
-                                      const TensorInfo& outputGateBias,
-                                      const TensorInfo* inputToInputWeights,
-                                      const TensorInfo* recurrentToInputWeights,
-                                      const TensorInfo* cellToInputWeights,
-                                      const TensorInfo* inputGateBias,
-                                      const TensorInfo* projectionWeights,
-                                      const TensorInfo* projectionBias,
-                                      const TensorInfo* cellToForgetWeights,
-                                      const TensorInfo* cellToOutputWeights,
-                                      Optional<std::string&> reasonIfUnsupported,
-                                      const TensorInfo* inputLayerNormWeights,
-                                      const TensorInfo* forgetLayerNormWeights,
-                                      const TensorInfo* cellLayerNormWeights,
-                                      const TensorInfo* outputLayerNormWeights) const
+                                      const LstmInputParamsInfo& paramsInfo,
+                                      Optional<std::string&> reasonIfUnsupported) const
 {
     ignore_unused(descriptor);
-    ignore_unused(inputToForgetWeights);
-    ignore_unused(inputToCellWeights);
-    ignore_unused(inputToOutputWeights);
-    ignore_unused(recurrentToForgetWeights);
-    ignore_unused(recurrentToCellWeights);
-    ignore_unused(recurrentToOutputWeights);
-    ignore_unused(forgetGateBias);
-    ignore_unused(cellBias);
-    ignore_unused(outputGateBias);
-    ignore_unused(inputToInputWeights);
-    ignore_unused(recurrentToInputWeights);
-    ignore_unused(cellToInputWeights);
-    ignore_unused(inputGateBias);
-    ignore_unused(projectionWeights);
-    ignore_unused(projectionBias);
-    ignore_unused(cellToForgetWeights);
-    ignore_unused(cellToOutputWeights);
-    ignore_unused(inputLayerNormWeights);
-    ignore_unused(forgetLayerNormWeights);
-    ignore_unused(cellLayerNormWeights);
-    ignore_unused(outputLayerNormWeights);
+    ignore_unused(paramsInfo);
 
     bool supported = true;
 
@@ -977,26 +937,91 @@
         DataType::QuantisedSymm16
     };
 
+    // check inputs and outputs
     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                   "Reference Lstm: input is not a supported type.");
-
     supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
                                   "Reference Lstm: input and outputStateIn types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
                                   "Reference Lstm: input and cellStateIn types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
                                   "Reference Lstm: input and scratchBuffer types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
                                   "Reference Lstm: input and outputStateOut types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
                                   "Reference Lstm: input and cellStateOut types are mismatched");
-
     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                   "Reference Lstm: input and output types are mismatched");
+    // check layer parameters
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and InputToForgetWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and InputToCellWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and InputToOutputWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
+                                  "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
+                                  "Reference Lstm: input and ForgetGateBias types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
+                                  "Reference Lstm: input and CellBias types are mismatched");
+    supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
+                                  "Reference Lstm: input and OutputGateBias types are mismatched");
+    if (!descriptor.m_CifgEnabled)
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and InputToInputWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
+                                      "Reference Lstm: input and InputGateBias types are mismatched");
+        if (descriptor.m_PeepholeEnabled)
+        {
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
+                                          reasonIfUnsupported,
+                                          "Reference Lstm: input and CellToInputWeights types are mismatched");
+        }
+    }
+    if (descriptor.m_PeepholeEnabled)
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and CellToForgetWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and CellToOutputWeights types are mismatched");
+    }
+    if (descriptor.m_ProjectionEnabled)
+    {
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
+                                      "Reference Lstm: input and mProjectionWeights types are mismatched");
+        if (paramsInfo.m_ProjectionBias != nullptr)
+        {
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
+                                          "Reference Lstm: input and ProjectionBias types are mismatched");
+        }
+    }
+    if (descriptor.m_LayerNormEnabled)
+    {
+        if (!descriptor.m_CifgEnabled)
+        {
+            supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
+                                          reasonIfUnsupported,
+                                          "Reference Lstm: input and InputLayerNormWeights types are mismatched");
+        }
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and CellLayerNormWeights types are mismatched");
+        supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
+                                      reasonIfUnsupported,
+                                      "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
+    }
 
     return supported;
 }
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index ead4d1c..c0bf188 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -138,28 +138,8 @@
                          const TensorInfo& cellStateOut,
                          const TensorInfo& output,
                          const LstmDescriptor& descriptor,
-                         const TensorInfo& inputToForgetWeights,
-                         const TensorInfo& inputToCellWeights,
-                         const TensorInfo& inputToOutputWeights,
-                         const TensorInfo& recurrentToForgetWeights,
-                         const TensorInfo& recurrentToCellWeights,
-                         const TensorInfo& recurrentToOutputWeights,
-                         const TensorInfo& forgetGateBias,
-                         const TensorInfo& cellBias,
-                         const TensorInfo& outputGateBias,
-                         const TensorInfo* inputToInputWeights,
-                         const TensorInfo* recurrentToInputWeights,
-                         const TensorInfo* cellToInputWeights,
-                         const TensorInfo* inputGateBias,
-                         const TensorInfo* projectionWeights,
-                         const TensorInfo* projectionBias,
-                         const TensorInfo* cellToForgetWeights,
-                         const TensorInfo* cellToOutputWeights,
-                         Optional<std::string&> reasonIfUnsupported = EmptyOptional(),
-                         const TensorInfo* inputLayerNormWeights = nullptr,
-                         const TensorInfo* forgetLayerNormWeights = nullptr,
-                         const TensorInfo* cellLayerNormWeights = nullptr,
-                         const TensorInfo* outputLayerNormWeights = nullptr) const override;
+                         const LstmInputParamsInfo& paramsInfo,
+                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
     bool IsMaximumSupported(const TensorInfo& input0,
                             const TensorInfo& input1,