blob: 5014e7527b4b348ba412267a9c4b0a0e9e4eeb5c [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "ConversionUtils_1_2.hpp"
using Half = half_float::half;
namespace armnn_driver
{
using namespace armnn;
using namespace android::nn;
/// Converts an ELU operation into an activation described by an ActivationDescriptor.
///
/// Input 0 is the tensor operand. Input 1 is the scalar alpha: it is read as FLOAT16 when
/// input 0 is TENSOR_FLOAT16 and as FLOAT32 when input 0 is TENSOR_FLOAT32. Any other
/// tensor type is rejected.
///
/// @param operation  The operation being converted.
/// @param model      The model the operation's operands are looked up in.
/// @param data       Conversion state forwarded to the helper routines.
/// @return The result of Fail(...) when an input cannot be read or the tensor type is not
///         supported; otherwise the result of ConvertToActivation.
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input0.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    // Determine data type of input tensor; it selects how the alpha scalar is read below.
    HalOperandType inputType;
    if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    ActivationDescriptor desc;
    desc.m_Function = ActivationFunction::Elu;

    // Read alpha (input 1) using the scalar type that corresponds to the tensor type.
    if (inputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half alpha;
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
        }
        // The descriptor stores alpha as a float, so widen the half-precision value.
        desc.m_A = static_cast<float>(alpha);
    }
    else if (inputType == HalOperandType::TENSOR_FLOAT32)
    {
        // Alpha is read straight into the descriptor field.
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
        }
    }
    else
    {
        // Convert explicitly to int so the argument matches the %d conversion: an enum value
        // (in particular a scoped enum, which is not integrally promoted) must not be handed
        // to a printf-style formatter as-is.
        return Fail("%s: Unsupported input tensor type: %d", __func__, static_cast<int>(inputType));
    }

    return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
}
} // armnn_driver namespace