blob: adf264aabbcc24749fa06a3f874a45a8bfb6f7f9 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "DelegateOptions.hpp"
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
namespace armnnDelegate
{
/// State shared by the node visitors while an ArmNN network is being built for a TfLite subgraph.
struct DelegateData
{
    /// Creates the state for building a network that targets @p backends.
    /// Marked explicit so a backend list is never silently converted into a DelegateData.
    explicit DelegateData(const std::vector<armnn::BackendId>& backends)
        : m_Backends(backends)
        , m_Network(nullptr, nullptr) // null network pointer and null deleter until a network is created
    {}

    /// Backends the network is to be built for; fixed for the lifetime of this object.
    const std::vector<armnn::BackendId> m_Backends;
    /// The ArmNN network under construction (starts out empty).
    armnn::INetworkPtr m_Network;
    /// Output slots recorded while visiting nodes.
    /// NOTE(review): presumably indexed by TfLite tensor index so later layers can connect to them — confirm.
    std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
};
// Forward declarations of the free functions used to create, configure and destroy the ArmNN Delegate.
// DoPrepare must be declared before class Delegate, because Delegate::m_Delegate names it as the
// TfLiteDelegate `.Prepare` callback in its default member initializer.
/// Returns a DelegateOptions object (presumably holding the default option values — confirm in the definition).
DelegateOptions TfLiteArmnnDelegateOptionsDefault();
/// Creates an ArmNN delegate configured by @p options and returns it as a TfLiteDelegate handle.
/// NOTE(review): presumably the caller owns the result and releases it with TfLiteArmnnDelegateDelete — confirm.
TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
/// Destroys @p tfLiteDelegate (presumably a handle obtained from TfLiteArmnnDelegateCreate).
void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
/// TfLiteDelegate Prepare callback; installed as `.Prepare` in Delegate::m_Delegate.
TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
/// ArmNN Delegate.
///
/// Owns the TfLiteDelegate structure handed to TensorFlow Lite. That structure's `data_` field
/// stores a pointer back to this Delegate, so the object must stay at a fixed address while
/// TensorFlow Lite can still reach the TfLiteDelegate. Copying and moving are therefore disabled:
/// a copy or move would leave the new object's `data_` pointing at the old object.
class Delegate
{
    friend class ArmnnSubgraph;
public:
    /// Constructs a delegate configured by @p options.
    explicit Delegate(armnnDelegate::DelegateOptions options);

    // Non-copyable and non-movable: m_Delegate.data_ is bound to this object's address.
    Delegate(const Delegate&) = delete;
    Delegate& operator=(const Delegate&) = delete;
    Delegate(Delegate&&) = delete;
    Delegate& operator=(Delegate&&) = delete;

    /// Returns the nodes of @p context to hand to this delegate.
    /// NOTE(review): presumably the indices of the nodes ArmNN can execute — confirm in the definition.
    TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);

    /// Returns the TfLiteDelegate handle for this delegate (presumably &m_Delegate — confirm).
    TfLiteDelegate* GetDelegate();

private:
    /// TfLiteDelegate passed to TensorFlow Lite. `data_` carries `this` so the callbacks can
    /// (presumably) recover the owning Delegate; only the Prepare callback is provided.
    TfLiteDelegate m_Delegate = {
        static_cast<void*>(this),  // .data_
        DoPrepare,                 // .Prepare
        nullptr,                   // .CopyFromBufferHandle
        nullptr,                   // .CopyToBufferHandle
        nullptr,                   // .FreeBufferHandle
        kTfLiteDelegateFlagsNone,  // .flags
    };

    /// ArmNN Runtime pointer.
    armnn::IRuntimePtr m_Runtime;

    /// ArmNN Delegate Options.
    armnnDelegate::DelegateOptions m_Options;
};
/// ArmnnSubgraph parses the nodes of a delegated TfLite subgraph into ArmNN format, creates the
/// corresponding ArmNN graph, and prepares and invokes it.
class ArmnnSubgraph
{
public:
    /// Builds an ArmnnSubgraph for the nodes described by @p parameters, using the settings of @p delegate.
    /// NOTE(review): ownership of the returned pointer and its null-on-failure behaviour are defined
    /// elsewhere — confirm in the definition.
    static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
                                 const TfLiteDelegateParams* parameters,
                                 const Delegate* delegate);

    /// Prepare step for this subgraph in @p tfLiteContext.
    TfLiteStatus Prepare(TfLiteContext* tfLiteContext);

    /// Runs the subgraph for @p tfLiteNode in @p tfLiteContext.
    TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);

    /// Visits a single TfLite node (identified by @p nodeIndex and its registration) and records the
    /// result in @p delegateData.
    static TfLiteStatus VisitNode(DelegateData& delegateData,
                                  TfLiteContext* tfLiteContext,
                                  TfLiteRegistration* tfLiteRegistration,
                                  TfLiteNode* tfLiteNode,
                                  int nodeIndex);

private:
    /// Stores the identifiers and copies of the binding information; the bindings are only read,
    /// so they are taken by const reference (this also accepts temporaries and const vectors).
    ArmnnSubgraph(armnn::NetworkId networkId,
                  armnn::IRuntime* runtime,
                  const std::vector<armnn::BindingPointInfo>& inputBindings,
                  const std::vector<armnn::BindingPointInfo>& outputBindings)
        : m_NetworkId(networkId)
        , m_Runtime(runtime)
        , m_InputBindings(inputBindings)
        , m_OutputBindings(outputBindings)
    {}

    /// Adds the network's input layer(s) for @p inputs and appends their binding info to @p inputBindings.
    static TfLiteStatus AddInputLayer(DelegateData& delegateData,
                                      TfLiteContext* tfLiteContext,
                                      const TfLiteIntArray* inputs,
                                      std::vector<armnn::BindingPointInfo>& inputBindings);

    /// Adds the network's output layer(s) for @p outputs and appends their binding info to @p outputBindings.
    static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
                                       TfLiteContext* tfLiteContext,
                                       const TfLiteIntArray* outputs,
                                       std::vector<armnn::BindingPointInfo>& outputBindings);

    /// The Network Id.
    armnn::NetworkId m_NetworkId;

    /// ArmNN Runtime (non-owning raw pointer; presumably owned by the Delegate — confirm).
    armnn::IRuntime* m_Runtime;

    /// Binding information for the network inputs.
    std::vector<armnn::BindingPointInfo> m_InputBindings;

    /// Binding information for the network outputs.
    std::vector<armnn::BindingPointInfo> m_OutputBindings;
};
} // armnnDelegate namespace