blob: d333fa736e3217ddff0f0ed3155a5cdb31f3a563 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>
#include <utility>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"
// TODO(b/139446230): Move to portable platform header.
// Defines TFLITE_IS_MOBILE_PLATFORM when building for Android, or for an
// Apple target where TARGET_IPHONE_SIMULATOR or TARGET_OS_IPHONE is non-zero.
// The Interpreter constructor tests this macro to pick its logging macro.
#if defined(__ANDROID__)
#define TFLITE_IS_MOBILE_PLATFORM
#endif // defined(__ANDROID__)
#if defined(__APPLE__)
#include "TargetConditionals.h"
#if TARGET_IPHONE_SIMULATOR
#define TFLITE_IS_MOBILE_PLATFORM
#elif TARGET_OS_IPHONE
#define TFLITE_IS_MOBILE_PLATFORM
#endif
#endif // defined(__APPLE__)
// TODO(b/132087118): move static_assert to c_api_internal when compiled with
// C++.
// Compile-time guard: the build fails unless TfLiteFloat16 has exactly the
// same size as uint16_t.
static_assert(sizeof(TfLiteFloat16) == sizeof(uint16_t),
"Float 16 type must be 16 bits.");
namespace tflite {
namespace {
// Gets the current TfLiteQuantization from the legacy TfLiteQuantizationParams.
//
// The returned struct has type kTfLiteAffineQuantization and its `params`
// points at a malloc-allocated TfLiteAffineQuantization that carries exactly
// one scale and one zero point copied from `legacy_quantization`. This
// function never frees that allocation; it is handed to the caller through
// the returned value.
TfLiteQuantization GetQuantizationFromLegacy(
    const TfLiteQuantizationParams& legacy_quantization) {
  auto* affine_quantization = static_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(1);
  affine_quantization->zero_point = TfLiteIntArrayCreate(1);
  affine_quantization->scale->data[0] = legacy_quantization.scale;
  affine_quantization->zero_point->data[0] = legacy_quantization.zero_point;
  // malloc() leaves the struct's bytes indeterminate. The legacy parameters
  // describe a single scale/zero-point pair, so pin the quantized dimension
  // to 0 instead of leaving garbage for later readers.
  affine_quantization->quantized_dimension = 0;

  TfLiteQuantization quantization;
  quantization.type = kTfLiteAffineQuantization;
  quantization.params = affine_quantization;
  return quantization;
}
} // namespace
// Builds an interpreter that reports errors through `error_reporter`, falling
// back to DefaultErrorReporter() when `error_reporter` is null. The
// constructor creates the primary subgraph, installs an interpreter-owned CPU
// backend context in the kTfLiteCpuBackendContext slot and calls
// UseNNAPI(false).
Interpreter::Interpreter(ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
// TODO(b/128420794): Include the TFLite runtime version in the log.
// Prod logging is useful for mobile platforms where scraping console logs is
// critical for debugging.
#if defined(TFLITE_IS_MOBILE_PLATFORM)
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime.");
#else
TFLITE_LOG_ONCE(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime.");
#endif
// There's always at least 1 subgraph which is the primary subgraph.
AddSubgraphs(1);
// Cache the primary subgraph's context; it is the context used for error
// reporting and external-context refreshes elsewhere in this file.
context_ = primary_subgraph().context();
// Start with every external context slot empty.
for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
external_contexts_[i] = nullptr;
}
// This operation is cheap because we allocate the CPU context resources (i.e.
// threads) lazily.
own_external_cpu_backend_context_.reset(new ExternalCpuBackendContext());
external_contexts_[kTfLiteCpuBackendContext] =
own_external_cpu_backend_context_.get();
UseNNAPI(false);
}
// Clears the caches of a CPU backend context that this interpreter uses but
// does not own.
Interpreter::~Interpreter() {
  auto* const registered_context = external_contexts_[kTfLiteCpuBackendContext];

  // No CPU backend context registered: nothing to clean up.
  if (!registered_context) return;

  // The interpreter-owned context goes away together with this object, so it
  // needs no explicit cache clearing.
  if (registered_context == own_external_cpu_backend_context_.get()) return;

  // A context that is not owned here may be used by other interpreters, so
  // its caches are cleared before this interpreter disappears.
  auto* const cpu_backend_context =
      static_cast<ExternalCpuBackendContext*>(registered_context);
  TfLiteInternalBackendContext* const backend =
      cpu_backend_context->internal_backend_context();
  if (backend) {
    // Clearing may slow down the next inference of any interpreter using this
    // context; that inference refreshes the cache.
    backend->ClearCaches();
  }
}
// Registers `ctx` as the external context for slot `type`.
//
// Passing the interpreter's own CPU backend context is rejected with a
// warning and leaves the interpreter unchanged. Passing nullptr is always
// forwarded to the primary subgraph, so a slot can be cleared even after the
// internally owned context has been released.
void Interpreter::SetExternalContext(TfLiteExternalContextType type,
                                     TfLiteExternalContext* ctx) {
  auto* const owned_cpu_context = own_external_cpu_backend_context_.get();

  // The null check matters: once the owned context has been released,
  // `owned_cpu_context` is null too, and a plain equality test would treat a
  // request to clear a slot (ctx == nullptr) as "identical to the owned one"
  // and silently drop it.
  if (ctx != nullptr && ctx == owned_cpu_context) {
    error_reporter_->Report(
        "WARNING: The passed external context is identical to the internally "
        "owned one.");
    return;
  }

  // We have an internally owned external context of kTfLiteCpuBackendContext.
  // If it's overwritten here, we will release the resource of the internally
  // owned external context.
  // Note: the 'max thread count' info associated with the overwritten context
  // will be lost here, and such info is now determined by the new context,
  // thus affecting how much parallelism a TFLite op would have.
  if (kTfLiteCpuBackendContext == type &&
      external_contexts_[kTfLiteCpuBackendContext] == owned_cpu_context) {
    own_external_cpu_backend_context_.reset();
  }

  // This essentially changes the "external_contexts_[type]".
  primary_subgraph().SetExternalContext(type, ctx);
}
// Hands `inputs` over to the primary subgraph's SetInputs and returns its
// status.
TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  auto&& primary = primary_subgraph();
  return primary.SetInputs(std::move(inputs));
}
// Hands `outputs` over to the primary subgraph's SetOutputs and returns its
// status.
TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  auto&& primary = primary_subgraph();
  return primary.SetOutputs(std::move(outputs));
}
// Hands `variables` over to the primary subgraph's SetVariables and returns
// its status.
TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
  auto&& primary = primary_subgraph();
  return primary.SetVariables(std::move(variables));
}
// Delegates tensor allocation to the primary subgraph and returns its status.
TfLiteStatus Interpreter::AllocateTensors() {
  auto&& primary = primary_subgraph();
  return primary.AllocateTensors();
}
// Forwards the requested node `count` to the primary subgraph's ReserveNodes.
void Interpreter::ReserveNodes(int count) {
  auto&& primary = primary_subgraph();
  primary.ReserveNodes(count);
}
// Appends `subgraphs_to_add` new subgraphs to `subgraphs_`.
//
// When `first_new_subgraph_index` is non-null it receives the index the first
// new subgraph will occupy, i.e. the number of subgraphs before the call.
// Each new Subgraph is constructed with this interpreter's error reporter,
// external context table, subgraph list and resources.
void Interpreter::AddSubgraphs(int subgraphs_to_add,
                               int* first_new_subgraph_index) {
  const size_t first_index = subgraphs_.size();
  if (first_new_subgraph_index != nullptr) {
    *first_new_subgraph_index = static_cast<int>(first_index);
  }

  // Grow the storage once up front rather than on every insertion.
  subgraphs_.reserve(first_index + subgraphs_to_add);
  for (int remaining = subgraphs_to_add; remaining > 0; --remaining) {
    subgraphs_.emplace_back(new Subgraph(error_reporter_, external_contexts_,
                                         &subgraphs_, &resources_));
  }
}
// Adds a node to the primary subgraph and returns that call's status.
//
// All arguments are passed through unchanged; the subgraph parameter that
// sits between `outputs` and `init_data` is given an empty braced initializer.
TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  auto&& primary = primary_subgraph();
  return primary.AddNodeWithParameters(inputs, outputs, {}, init_data,
                                       init_data_size, builtin_data,
                                       registration, node_index);
}
// Asks the primary subgraph to resize tensor `tensor_index` to `dims` and
// returns its status.
TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  auto&& primary = primary_subgraph();
  return primary.ResizeInputTensor(tensor_index, dims);
}
// Releases non-persistent memory of the primary subgraph only and returns
// that call's status.
TfLiteStatus Interpreter::ReleaseNonPersistentMemory() {
  // TODO(b/138790287): We could do this for all subgraphs whose tensors have
  // been allocated. However, AllocateTensors() relies on Control Flow ops to
  // allocate tensors on 'children' subgraphs. Revisit this if required.
  auto&& primary = primary_subgraph();
  return primary.ReleaseNonPersistentMemory();
}
// Runs the primary subgraph. Unless buffer-handle outputs are allowed, it
// then makes sure the data of every output tensor is readable.
TfLiteStatus Interpreter::Invoke() {
  TF_LITE_ENSURE_STATUS(primary_subgraph().Invoke());

  // Callers that accept buffer handles as outputs skip the readability pass.
  if (allow_buffer_handle_output_) {
    return kTfLiteOk;
  }

  for (int output_index : outputs()) {
    TF_LITE_ENSURE_STATUS(
        primary_subgraph().EnsureTensorDataIsReadable(output_index));
  }
  return kTfLiteOk;
}
// Forwards both arguments to the primary subgraph's AddTensors and returns
// its status.
TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  auto&& primary = primary_subgraph();
  return primary.AddTensors(tensors_to_add, first_new_tensor_index);
}
// Delegates to the primary subgraph's ResetVariableTensors and returns its
// status.
TfLiteStatus Interpreter::ResetVariableTensors() {
  auto&& primary = primary_subgraph();
  return primary.ResetVariableTensors();
}
// Vector-based overload: unpacks `dims` into a (rank, pointer) pair and
// forwards everything to the primary subgraph's SetTensorParametersReadOnly.
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    const char* buffer, size_t bytes, const Allocation* allocation) {
  const auto rank = dims.size();
  const int* const dims_data = dims.data();
  auto&& primary = primary_subgraph();
  return primary.SetTensorParametersReadOnly(tensor_index, type, name, rank,
                                             dims_data, quantization, buffer,
                                             bytes, allocation);
}
// Vector-based overload: unpacks `dims` into a (rank, pointer) pair and
// forwards everything to the primary subgraph's SetTensorParametersReadWrite.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    bool is_variable) {
  const auto rank = dims.size();
  const int* const dims_data = dims.data();
  auto&& primary = primary_subgraph();
  return primary.SetTensorParametersReadWrite(
      tensor_index, type, name, rank, dims_data, quantization, is_variable);
}
// Legacy-quantization overload: converts `quantization` with
// GetQuantizationFromLegacy and forwards the result, together with the other
// arguments, to the primary subgraph's SetTensorParametersReadOnly.
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
    size_t bytes, const Allocation* allocation) {
  TfLiteQuantization converted = GetQuantizationFromLegacy(quantization);
  auto&& primary = primary_subgraph();
  return primary.SetTensorParametersReadOnly(tensor_index, type, name, rank,
                                             dims, converted, buffer, bytes,
                                             allocation);
}
// Legacy-quantization overload: converts `quantization` with
// GetQuantizationFromLegacy and forwards the result, together with the other
// arguments, to the primary subgraph's SetTensorParametersReadWrite.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
  TfLiteQuantization converted = GetQuantizationFromLegacy(quantization);
  auto&& primary = primary_subgraph();
  return primary.SetTensorParametersReadWrite(tensor_index, type, name, rank,
                                              dims, converted, is_variable);
}
// Passes `new_plan` to the primary subgraph's SetExecutionPlan and returns
// its status.
TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  auto&& primary = primary_subgraph();
  return primary.SetExecutionPlan(new_plan);
}
// Forwards the `enable` flag to the primary subgraph's UseNNAPI.
void Interpreter::UseNNAPI(bool enable) {
  auto&& primary = primary_subgraph();
  primary.UseNNAPI(enable);
}
void Interpreter::SetNumThreads(int num_threads) {
if (num_threads < -1) {
context_->ReportError(context_,
"num_threads should be >=0 or just -1 to let TFLite "
"runtime set the value.");
return;
}
for (auto& subgraph : subgraphs_) {
subgraph->context()->recommended_num_threads = num_threads;
}
for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
auto* c = external_contexts_[i];
if (c && c->Refresh) {
c->Refresh(context_);
}
}
}
void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {
for (auto& subgraph : subgraphs_) {
subgraph->context()->allow_fp32_relax_to_fp16 = allow;
}
}
// TODO(b/121264966): Subgraphs added after cancellation is set will not get the
// cancellation function added to their context.
void Interpreter::SetCancellationFunction(void* data,
bool (*check_cancelled_func)(void*)) {
for (auto& subgraph : subgraphs_) {
subgraph->SetCancellationFunction(data, check_cancelled_func);
}
}
// Applies the caller-owned `delegate` to each subgraph in order, stopping at
// the first subgraph for which the TF_LITE_ENSURE_OK check fails. Returns
// kTfLiteOk when every subgraph accepted the delegate.
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  for (const auto& graph : subgraphs_) {
    const TfLiteStatus delegate_status =
        graph->ModifyGraphWithDelegate(delegate);
    TF_LITE_ENSURE_OK(context_, delegate_status);
  }
  return kTfLiteOk;
}
// Owning overload: stores `delegate` in `owned_delegates_` and then applies
// it through the raw-pointer overload.
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegatePtr delegate) {
  // Ownership is kept even when graph modification fails, because delegate
  // use is in an indeterminate state at that point.
  owned_delegates_.emplace_back(std::move(delegate));
  TfLiteDelegate* const stored_delegate = owned_delegates_.back().get();
  return ModifyGraphWithDelegate(stored_delegate);
}
// Associates `buffer_handle` and `delegate` with tensor `tensor_index` of the
// primary subgraph.
//
// The tensor must either have no delegate yet or already use `delegate`. If
// the tensor currently holds a non-null buffer handle, that handle is
// released through the delegate's FreeBufferHandle before the new one is
// stored. Returns kTfLiteOk on success; any failed TF_LITE_ENSURE check exits
// the function early (presumably with an error status - the macro is defined
// outside this file).
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle buffer_handle,
                                          TfLiteDelegate* delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];
  TF_LITE_ENSURE(context_,
                 tensor->delegate == nullptr || tensor->delegate == delegate);
  tensor->delegate = delegate;
  if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
    // An existing handle can only be released through a delegate. Without
    // this check a null `delegate` would be dereferenced on the next line.
    TF_LITE_ENSURE(context_, tensor->delegate != nullptr);
    TF_LITE_ENSURE(context_, tensor->delegate->FreeBufferHandle != nullptr);
    tensor->delegate->FreeBufferHandle(context_, tensor->delegate,
                                       &tensor->buffer_handle);
  }
  tensor->buffer_handle = buffer_handle;
  return kTfLiteOk;
}
// Reads the delegate and buffer handle currently stored on tensor
// `tensor_index` of the primary subgraph into `*delegate` and
// `*buffer_handle`. Both output pointers are written unconditionally once the
// index check has passed. Returns kTfLiteOk in that case.
TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle* buffer_handle,
                                          TfLiteDelegate** delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  const TfLiteTensor& source = primary_subgraph().tensors()[tensor_index];
  *delegate = source.delegate;
  *buffer_handle = source.buffer_handle;
  return kTfLiteOk;
}
// Installs a caller-owned `profiler` on all subgraphs.
void Interpreter::SetProfiler(Profiler* profiler) {
  // The caller-owned profiler replaces any profiler owned by this
  // interpreter, so drop the owned one first.
  owned_profiler_.reset();
  SetSubgraphProfiler(profiler);
}
void Interpreter::SetProfiler(std::unique_ptr<Profiler> profiler) {
owned_profiler_ = std::move(profiler);
SetSubgraphProfiler(owned_profiler_.get());
}
// Calls SetProfiler on every subgraph, passing `profiler` together with the
// subgraph's position in `subgraphs_`.
void Interpreter::SetSubgraphProfiler(Profiler* profiler) {
  const int subgraph_count = static_cast<int>(subgraphs_.size());
  for (int index = 0; index < subgraph_count; ++index) {
    subgraphs_[index]->SetProfiler(profiler, index);
  }
}
// Returns whatever profiler the primary subgraph reports.
Profiler* Interpreter::GetProfiler() {
  auto&& primary = primary_subgraph();
  return primary.GetProfiler();
}
} // namespace tflite