IVGCVSW-3145 Refactor Reference Reshape workloads

* Removed reference reshape workloads for float32 and uint8
* Added RefReshapeWorkload
* Added check for supported datatypes for reshape in WorkloadData
* Added check for supported datatypes for reshape in RefLayerSupport
* Updated CMakeLists.txt
* Updated references to reshape workloads

Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Change-Id: I9941659067b022f8f7686ab0ff14776944dca3e5
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d9779e4..ea84c0b 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -850,13 +850,13 @@
 
     // Check the supported data types
     std::vector<DataType> supportedTypes =
-                              {
-                                  DataType::Float32,
-                                  DataType::Float16,
-                                  DataType::Signed32,
-                                  DataType::QuantisedAsymm8,
-                                  DataType::QuantisedSymm16
-                              };
+    {
+        DataType::Float32,
+        DataType::Float16,
+        DataType::Signed32,
+        DataType::QuantisedAsymm8,
+        DataType::QuantisedSymm16
+    };
 
     ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ConstantQueueDescriptor");
 }
@@ -872,6 +872,17 @@
             to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " +
             to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements.");
     }
+
+    // Check the supported data types
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QuantisedAsymm8
+    };
+
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor");
+    ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor");
 }
 
 void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 119eb7d..067cca8 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -447,7 +447,7 @@
     AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
 
     // InvalidArgumentException is expected, because the number of elements don't match.
-    BOOST_CHECK_THROW(RefReshapeFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+    BOOST_CHECK_THROW(RefReshapeWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
 }
 
 
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 9be1ed6..2adcb10 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1021,10 +1021,15 @@
                                          Optional<std::string&> reasonIfUnsupported) const
 {
     ignore_unused(descriptor);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+    // Define supported input types.
+    std::array<DataType,3> supportedOutputTypes =
+    {
+        DataType::Float32,
+        DataType::Float16,
+        DataType::QuantisedAsymm8
+    };
+    return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
+        "Reference reshape: input type not supported.");
 }
 
 bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 6abcf9c..1243328 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -264,7 +264,7 @@
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
     const WorkloadInfo& info) const
 {
-    return MakeWorkload<RefReshapeFloat32Workload, RefReshapeUint8Workload>(descriptor, info);
+    return std::make_unique<RefReshapeWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 50cfbf6..1c7f8dc 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -54,8 +54,7 @@
         workloads/RefPooling2dFloat32Workload.cpp \
         workloads/RefPooling2dUint8Workload.cpp \
         workloads/RefQuantizeWorkload.cpp \
-        workloads/RefReshapeFloat32Workload.cpp \
-        workloads/RefReshapeUint8Workload.cpp \
+        workloads/RefReshapeWorkload.cpp \
         workloads/RefResizeBilinearFloat32Workload.cpp \
         workloads/RefResizeBilinearUint8Workload.cpp \
         workloads/RefRsqrtFloat32Workload.cpp \
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 95da7ab..3f4cc75 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -663,12 +663,12 @@
 
 BOOST_AUTO_TEST_CASE(CreateReshapeFloat32Workload)
 {
-    RefCreateReshapeWorkloadTest<RefReshapeFloat32Workload, armnn::DataType::Float32>();
+    RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::Float32>();
 }
 
 BOOST_AUTO_TEST_CASE(CreateReshapeUint8Workload)
 {
-    RefCreateReshapeWorkloadTest<RefReshapeUint8Workload, armnn::DataType::QuantisedAsymm8>();
+    RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QuantisedAsymm8>();
 }
 
 template <typename MergerWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 7f26d78..508dfdc 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -91,10 +91,8 @@
     RefPooling2dUint8Workload.hpp
     RefQuantizeWorkload.cpp
     RefQuantizeWorkload.hpp
-    RefReshapeFloat32Workload.cpp
-    RefReshapeFloat32Workload.hpp
-    RefReshapeUint8Workload.cpp
-    RefReshapeUint8Workload.hpp
+    RefReshapeWorkload.cpp
+    RefReshapeWorkload.hpp
     RefResizeBilinearFloat32Workload.cpp
     RefResizeBilinearFloat32Workload.hpp
     RefResizeBilinearUint8Workload.cpp
diff --git a/src/backends/reference/workloads/RefReshapeFloat32Workload.cpp b/src/backends/reference/workloads/RefReshapeFloat32Workload.cpp
deleted file mode 100644
index 99c94a4..0000000
--- a/src/backends/reference/workloads/RefReshapeFloat32Workload.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefReshapeFloat32Workload.hpp"
-
-#include "RefWorkloadUtils.hpp"
-
-#include "Profiling.hpp"
-
-#include <cstring>
-
-namespace armnn
-{
-
-void RefReshapeFloat32Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReshapeFloat32Workload_Execute");
-
-    void* output = GetOutputTensorData<void>(0, m_Data);
-    const void* input = GetInputTensorData<void>(0, m_Data);
-    unsigned int numBytes = GetTensorInfo(m_Data.m_Inputs[0]).GetNumBytes();
-    memcpy(output, input, numBytes);
-}
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefReshapeFloat32Workload.hpp b/src/backends/reference/workloads/RefReshapeFloat32Workload.hpp
deleted file mode 100644
index 75024b3..0000000
--- a/src/backends/reference/workloads/RefReshapeFloat32Workload.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-
-namespace armnn
-{
-
-class RefReshapeFloat32Workload : public Float32Workload<ReshapeQueueDescriptor>
-{
-public:
-    using Float32Workload<ReshapeQueueDescriptor>::Float32Workload;
-    virtual void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/backends/reference/workloads/RefReshapeUint8Workload.cpp b/src/backends/reference/workloads/RefReshapeWorkload.cpp
similarity index 71%
rename from src/backends/reference/workloads/RefReshapeUint8Workload.cpp
rename to src/backends/reference/workloads/RefReshapeWorkload.cpp
index 8f475f3..6d29781 100644
--- a/src/backends/reference/workloads/RefReshapeUint8Workload.cpp
+++ b/src/backends/reference/workloads/RefReshapeWorkload.cpp
@@ -1,12 +1,10 @@
-//
+//
 // Copyright © 2017 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
-#include "RefReshapeUint8Workload.hpp"
-
+#include "RefReshapeWorkload.hpp"
 #include "RefWorkloadUtils.hpp"
-
 #include "Profiling.hpp"
 
 #include <cstring>
@@ -14,9 +12,9 @@
 namespace armnn
 {
 
-void RefReshapeUint8Workload::Execute() const
+void RefReshapeWorkload::Execute() const
 {
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReshapeUint8Workload_Execute");
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReshapeWorkload_Execute");
 
     void* output = GetOutputTensorData<void>(0, m_Data);
     const void* input = GetInputTensorData<void>(0, m_Data);
diff --git a/src/backends/reference/workloads/RefReshapeUint8Workload.hpp b/src/backends/reference/workloads/RefReshapeWorkload.hpp
similarity index 65%
rename from src/backends/reference/workloads/RefReshapeUint8Workload.hpp
rename to src/backends/reference/workloads/RefReshapeWorkload.hpp
index c3d31f8..7359ff9 100644
--- a/src/backends/reference/workloads/RefReshapeUint8Workload.hpp
+++ b/src/backends/reference/workloads/RefReshapeWorkload.hpp
@@ -1,4 +1,4 @@
-//
+//
 // Copyright © 2017 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
@@ -11,10 +11,10 @@
 namespace armnn
 {
 
-class RefReshapeUint8Workload : public Uint8Workload<ReshapeQueueDescriptor>
+class RefReshapeWorkload : public BaseWorkload<ReshapeQueueDescriptor>
 {
 public:
-    using Uint8Workload<ReshapeQueueDescriptor>::Uint8Workload;
+    using BaseWorkload<ReshapeQueueDescriptor>::BaseWorkload;
     virtual void Execute() const override;
 };
 
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 54bc5c7..20649d9 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -23,14 +23,12 @@
 #include "TensorBufferArrayView.hpp"
 #include "RefBatchNormalizationFloat32Workload.hpp"
 #include "Splitter.hpp"
-#include "RefReshapeFloat32Workload.hpp"
 #include "RefDepthwiseConvolution2dWorkload.hpp"
 #include "FullyConnected.hpp"
 #include "Gather.hpp"
 #include "RefFloorFloat32Workload.hpp"
 #include "RefSoftmaxFloat32Workload.hpp"
 #include "RefSoftmaxUint8Workload.hpp"
-#include "RefReshapeUint8Workload.hpp"
 #include "RefResizeBilinearFloat32Workload.hpp"
 #include "RefBatchNormalizationUint8Workload.hpp"
 #include "ResizeBilinear.hpp"
@@ -59,3 +57,4 @@
 #include "RefRsqrtFloat32Workload.hpp"
 #include "RefDequantizeWorkload.hpp"
 #include "RefQuantizeWorkload.hpp"
+#include "RefReshapeWorkload.hpp"