IVGCVSW-2814 Extensive ref IsSupported for Activation & Addition

Change-Id: Ib1a795eb129de1ec3f02807a2dff7613d7c6c28d
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp
index 9380a96..589d849 100644
--- a/include/armnn/Tensor.hpp
+++ b/include/armnn/Tensor.hpp
@@ -80,6 +80,7 @@
     int32_t GetQuantizationOffset() const           { return m_Quantization.m_Offset; }
     void SetQuantizationScale(float scale)          { m_Quantization.m_Scale = scale; }
     void SetQuantizationOffset(int32_t offset)      { m_Quantization.m_Offset = offset; }
+    bool IsQuantized() const                        { return m_DataType == DataType::QuantisedAsymm8; }
 
     unsigned int GetNumBytes() const;
 
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 4b32a89..cdc6aca 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -9,11 +9,16 @@
 #include <InternalTypes.hpp>
 #include <LayerSupportCommon.hpp>
 #include <armnn/Types.hpp>
+#include <armnn/Descriptors.hpp>
 
 #include <backendsCommon/BackendRegistry.hpp>
 
 #include <boost/core/ignore_unused.hpp>
 
+#include <vector>
+#include <algorithm>
+#include <array>
+
 using namespace boost;
 
 namespace armnn
@@ -41,17 +46,171 @@
 
 } // anonymous namespace
 
+
+namespace
+{
+template<typename F>
+bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
+{
+    bool supported = rule();
+    if (!supported && reason)
+    {
+        reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
+    }
+    return supported;
+}
+
+struct Rule
+{
+    bool operator()() const
+    {
+        return m_Res;
+    }
+
+    bool m_Res = true;
+};
+
+template<class none = void>
+bool AllTypesAreEqualImpl()
+{
+    return true;
+}
+
+template<typename T, typename... Rest>
+bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
+{
+    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
+
+    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(rest...);
+}
+
+struct TypesAreEqual : public Rule
+{
+    template<typename ... Ts>
+    TypesAreEqual(const Ts&... ts)
+    {
+        m_Res = AllTypesAreEqualImpl(ts...);
+    }
+};
+
+struct QuantizationParametersAreEqual : public Rule
+{
+    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
+    {
+        m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
+                info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
+    }
+};
+
+struct TypeAnyOf : public Rule
+{
+    template<typename Container>
+    TypeAnyOf(const TensorInfo& info, const Container& c)
+    {
+        m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
+            {
+                return dt == info.GetDataType();
+            });
+    }
+};
+
+struct ShapesAreSameRank : public Rule
+{
+    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
+    {
+        m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
+    }
+};
+
+struct ShapesAreBroadcastCompatible : public Rule
+{
+    unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
+    {
+        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
+        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
+        return sizeIn;
+    }
+
+    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
+    {
+        const TensorShape& shape0 = in0.GetShape();
+        const TensorShape& shape1 = in1.GetShape();
+        const TensorShape& outShape = out.GetShape();
+
+        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
+        {
+            unsigned int sizeOut = outShape[i];
+            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
+            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
+
+            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
+                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
+        }
+    }
+};
+} // namespace
+
+
 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
                                             const TensorInfo& output,
                                             const ActivationDescriptor& descriptor,
                                             Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(output);
-    ignore_unused(descriptor);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+   bool supported = true;
+
+    // Define supported types.
+    std::array<DataType,2> supportedTypes = {
+        DataType::Float32,
+        DataType::QuantisedAsymm8
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference activation: input type not supported.");
+
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+                                  "Reference activation: output type not supported.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference activation: input and output types mismatched.");
+
+    supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
+                                  "Reference activation: input and output shapes are of different rank.");
+
+
+    struct ActivationFunctionSupported : public Rule
+    {
+        ActivationFunctionSupported(const ActivationDescriptor& desc)
+        {
+            switch(desc.m_Function)
+            {
+                case ActivationFunction::Abs:
+                case ActivationFunction::BoundedReLu:
+                case ActivationFunction::LeakyReLu:
+                case ActivationFunction::Linear:
+                case ActivationFunction::ReLu:
+                case ActivationFunction::Sigmoid:
+                case ActivationFunction::SoftReLu:
+                case ActivationFunction::Sqrt:
+                case ActivationFunction::Square:
+                case ActivationFunction::TanH:
+                {
+                    m_Res = true;
+                    break;
+                }
+                default:
+                {
+                    m_Res = false;
+                    break;
+                }
+            }
+        }
+    };
+
+    // Function is supported
+    supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
+                                  "Reference activation: function not supported.");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
@@ -59,12 +218,32 @@
                                           const TensorInfo& output,
                                           Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(input1);
-    ignore_unused(output);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input0.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+    bool supported = true;
+
+    std::array<DataType,2> supportedTypes = {
+        DataType::Float32,
+        DataType::QuantisedAsymm8
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+                                  "Reference addition: input 0 is not a supported type.");
+
+    supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+                                  "Reference addition: input 1 is not a supported type.");
+
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+                                  "Reference addition: output is not a supported type.");
+
+    supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+                                  "Reference addition: input 0 and Input 1 types are mismatched");
+
+    supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+                                  "Reference addition: input and output types are mismatched");
+
+    supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+                                  "Reference addition: shapes are not suitable for implicit broadcast.");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
diff --git a/src/backends/reference/test/RefLayerSupportTests.cpp b/src/backends/reference/test/RefLayerSupportTests.cpp
index b7fbc68..2c7e17d 100644
--- a/src/backends/reference/test/RefLayerSupportTests.cpp
+++ b/src/backends/reference/test/RefLayerSupportTests.cpp
@@ -9,6 +9,7 @@
 
 #include <backendsCommon/CpuTensorHandle.hpp>
 #include <reference/RefWorkloadFactory.hpp>
+#include <reference/RefLayerSupport.hpp>
 #include <backendsCommon/test/LayerTests.hpp>
 #include <backendsCommon/test/IsLayerSupportedTestImpl.hpp>
 
@@ -32,6 +33,20 @@
 {
     LayerTypeMatchesTest();
 }
+BOOST_AUTO_TEST_CASE(IsLayerSupportedReferenceAddition)
+{
+    armnn::TensorShape shape0 = {1,1,3,4};
+    armnn::TensorShape shape1 = {4};
+    armnn::TensorShape outShape = {1,1,3,4};
+    armnn::TensorInfo in0(shape0, armnn::DataType::Float32);
+    armnn::TensorInfo in1(shape1, armnn::DataType::Float32);
+    armnn::TensorInfo out(outShape, armnn::DataType::Float32);
+
+    armnn::RefLayerSupport supportChecker;
+    std::string reasonNotSupported;
+    BOOST_CHECK(supportChecker.IsAdditionSupported(in0, in1, out, reasonNotSupported));
+}
+
 
 BOOST_AUTO_TEST_CASE(IsLayerSupportedFloat16Reference)
 {