blob: 91dd3b17b796d9a69341fc6ca0274dc306256df9 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
#define TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_
#include <string>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/logging.h"
#include "tensorflow/lite/tools/tool_params.h"
namespace tflite {
namespace tools {
// Same w/ Interpreter::TfLiteDelegatePtr to avoid pulling
// tensorflow/lite/interpreter.h dependency
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
class DelegateProvider {
public:
virtual ~DelegateProvider() {}
// Create a list of command-line parsable flags based on tool params inside
// 'params' whose value will be set to the corresponding runtime flag value.
virtual std::vector<Flag> CreateFlags(ToolParams* params) const = 0;
// Log tool params.
virtual void LogParams(const ToolParams& params) const = 0;
// Create a TfLiteDelegate based on tool params.
virtual TfLiteDelegatePtr CreateTfLiteDelegate(
const ToolParams& params) const = 0;
virtual std::string GetName() const = 0;
const ToolParams& DefaultParams() const { return default_params_; }
protected:
template <typename T>
Flag CreateFlag(const char* name, ToolParams* params,
const std::string& usage) const {
return Flag(
name, [params, name](const T& val) { params->Set<T>(name, val); },
default_params_.Get<T>(name), usage, Flag::kOptional);
}
ToolParams default_params_;
};
using DelegateProviderPtr = std::unique_ptr<DelegateProvider>;
using DelegateProviderList = std::vector<DelegateProviderPtr>;
class DelegateProviderRegistrar {
public:
template <typename T>
struct Register {
Register() {
auto* const instance = DelegateProviderRegistrar::GetSingleton();
instance->providers_.emplace_back(DelegateProviderPtr(new T()));
}
};
static const DelegateProviderList& GetProviders() {
return GetSingleton()->providers_;
}
private:
DelegateProviderRegistrar() {}
DelegateProviderRegistrar(const DelegateProviderRegistrar&) = delete;
DelegateProviderRegistrar& operator=(const DelegateProviderRegistrar&) =
delete;
static DelegateProviderRegistrar* GetSingleton() {
static auto* instance = new DelegateProviderRegistrar();
return instance;
}
DelegateProviderList providers_;
};
#define REGISTER_DELEGATE_PROVIDER_VNAME(T) gDelegateProvider_##T##_
#define REGISTER_DELEGATE_PROVIDER(T) \
static DelegateProviderRegistrar::Register<T> \
REGISTER_DELEGATE_PROVIDER_VNAME(T);
// A global helper function to get all registered delegate providers.
inline const DelegateProviderList& GetRegisteredDelegateProviders() {
return DelegateProviderRegistrar::GetProviders();
}
} // namespace tools
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_DELEGATES_DELEGATE_PROVIDER_H_