blob: 08189f9999852fe9e00a6ed18c88a7d7e6f9339f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include <boost/assert.hpp>
#include <algorithm>
namespace armnn
{
inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
{
if (!weightsType)
{
return weightsType;
}
switch(weightsType.value())
{
case armnn::DataType::Float16:
case armnn::DataType::Float32:
return weightsType;
case armnn::DataType::QuantisedAsymm8:
return armnn::DataType::Signed32;
case armnn::DataType::QuantisedSymm16:
return armnn::DataType::Signed32;
default:
BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
}
return armnn::EmptyOptional();
}
template<typename F>
bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
{
bool supported = rule();
if (!supported && reason)
{
reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
}
return supported;
}
struct Rule
{
bool operator()() const
{
return m_Res;
}
bool m_Res = true;
};
template<typename T>
bool AllTypesAreEqualImpl(T)
{
return true;
}
template<typename T, typename... Rest>
bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
{
static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
}
struct TypesAreEqual : public Rule
{
template<typename ... Ts>
TypesAreEqual(const Ts&... ts)
{
m_Res = AllTypesAreEqualImpl(ts...);
}
};
struct QuantizationParametersAreEqual : public Rule
{
QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
{
m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
}
};
struct TypeAnyOf : public Rule
{
template<typename Container>
TypeAnyOf(const TensorInfo& info, const Container& c)
{
m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
{
return dt == info.GetDataType();
});
}
};
struct TypeIs : public Rule
{
TypeIs(const TensorInfo& info, DataType dt)
{
m_Res = dt == info.GetDataType();
}
};
struct BiasAndWeightsTypesMatch : public Rule
{
BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
{
m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
}
};
struct BiasAndWeightsTypesCompatible : public Rule
{
template<typename Container>
BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
{
m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
{
return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
});
}
};
struct ShapesAreSameRank : public Rule
{
ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
{
m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
}
};
struct ShapesAreSameTotalSize : public Rule
{
ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
{
m_Res = info0.GetNumElements() == info1.GetNumElements();
}
};
struct ShapesAreBroadcastCompatible : public Rule
{
unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
{
unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
return sizeIn;
}
ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
{
const TensorShape& shape0 = in0.GetShape();
const TensorShape& shape1 = in1.GetShape();
const TensorShape& outShape = out.GetShape();
for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
{
unsigned int sizeOut = outShape[i];
unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
((sizeIn1 == sizeOut) || (sizeIn1 == 1));
}
}
};
struct TensorNumDimensionsAreCorrect : public Rule
{
TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
{
m_Res = info.GetNumDimensions() == expectedNumDimensions;
}
};
} //namespace armnn