blob: 9cd8774984908469d8a4c030f8fd6d2291a21f9b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
namespace armnn
{
namespace
{
// Factory for one concrete workload type.
// Func() constructs a WorkloadType from the queue descriptor, the workload info and any
// additional constructor arguments, and hands ownership of it back to the caller.
template<typename WorkloadType>
struct MakeWorkloadForType
{
    // descriptor: queue descriptor passed as the first constructor argument.
    // info:       workload info passed as the second constructor argument.
    // args:       any further constructor arguments, perfectly forwarded.
    // Returns an owning pointer to the newly constructed WorkloadType.
    template<typename QueueDescriptorType, typename... Args>
    static auto Func(const QueueDescriptorType& descriptor,
                     const WorkloadInfo& info,
                     Args&&... args) -> std::unique_ptr<WorkloadType>
    {
        auto workload = std::make_unique<WorkloadType>(descriptor, info, std::forward<Args>(args)...);
        return workload;
    }
};
// Specialization for NullWorkload, the placeholder type used for unsupported DataType/WorkloadType combos.
// Nothing is constructed: Func() ignores all of its arguments and always returns nullptr.
template<>
struct MakeWorkloadForType<NullWorkload>
{
    // The parameters are intentionally left unnamed. They exist only so that the signature matches the
    // primary template's Func(); naming them would produce unused-parameter warnings in every
    // translation unit that instantiates this function.
    template<typename QueueDescriptorType, typename... Args>
    static std::unique_ptr<NullWorkload> Func(const QueueDescriptorType& /*descriptor*/,
                                              const WorkloadInfo& /*info*/,
                                              Args&&... /*args*/)
    {
        return nullptr;
    }
};
// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename Int32Workload,
typename BooleanWorkload, typename QueueDescriptorType, typename... Args>
std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
const WorkloadInfo& info,
Args&&... args)
{
const DataType dataType = !info.m_InputTensorInfos.empty() ?
info.m_InputTensorInfos[0].GetDataType()
: info.m_OutputTensorInfos[0].GetDataType();
switch (dataType)
{
case DataType::Float16:
return MakeWorkloadForType<Float16Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Float32:
return MakeWorkloadForType<Float32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::QuantisedAsymm8:
return MakeWorkloadForType<Uint8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Signed32:
return MakeWorkloadForType<Int32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Boolean:
return MakeWorkloadForType<BooleanWorkload>::Func(descriptor, info, std::forward<Args>(args)...);
default:
BOOST_ASSERT_MSG(false, "Unknown DataType.");
return nullptr;
}
}
// Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
// Calling this method is the equivalent of calling the five typed MakeWorkload method with <FloatWorkload,
// FloatWorkload, Uint8Workload, NullWorkload, NullWorkload>.
// Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
const WorkloadInfo& info,
Args&&... args)
{
return MakeWorkloadHelper<FloatWorkload, FloatWorkload, Uint8Workload, NullWorkload, NullWorkload>(
descriptor,
info,
std::forward<Args>(args)...);
}
} //namespace
} //namespace armnn