blob: 6ef6c2ce1940e88a5488a29779bdacc8d3f28253 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <utility>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"
// TODO(b/132087118): move static_assert to c_api_internal when compiled with
// C++.
// Compile-time layout guard: TfLiteFloat16 must be exactly as large as a
// uint16_t.
static_assert(sizeof(uint16_t) == sizeof(TfLiteFloat16),
              "Float 16 type must be 16 bits.");
namespace tflite {
namespace {
// Gets the current TfLiteQuantization from the legacy TfLiteQuantizationParams.
//
// The result has type kTfLiteAffineQuantization. Its `params` points to a
// heap-allocated TfLiteAffineQuantization whose `scale` and `zero_point` are
// one-element arrays holding the legacy scale and zero point.
//
// The allocation is not freed here. Both callers in this file pass the result
// straight to the primary subgraph. Presumably the subgraph takes ownership
// (TODO confirm in Subgraph::SetTensorParameters*).
TfLiteQuantization GetQuantizationFromLegacy(
    const TfLiteQuantizationParams& legacy_quantization) {
  TfLiteQuantization quantization;
  quantization.type = kTfLiteAffineQuantization;
  // calloc zero-fills the struct. Any member this function does not assign
  // explicitly therefore starts at zero instead of holding indeterminate
  // malloc garbage.
  auto* affine_quantization = static_cast<TfLiteAffineQuantization*>(
      calloc(1, sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(1);
  affine_quantization->zero_point = TfLiteIntArrayCreate(1);
  affine_quantization->scale->data[0] = legacy_quantization.scale;
  affine_quantization->zero_point->data[0] = legacy_quantization.zero_point;
  quantization.params = affine_quantization;
  return quantization;
}
}  // namespace
// Creates an interpreter with one (primary) subgraph.
//
// Falls back to DefaultErrorReporter() when `error_reporter` is null. The
// constructor installs an internally owned ExternalCpuBackendContext in the
// kTfLiteCpuBackendContext slot and calls UseNNAPI(false).
Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  // TODO(b/128420794): Include the TFLite runtime version in the log.
  TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime.");
  // Start with every external-context slot empty. AddSubgraphs() hands each
  // new Subgraph a pointer to this array, so the slots are cleared before any
  // subgraph exists.
  for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
    external_contexts_[i] = nullptr;
  }
  // There's always at least 1 subgraph which is the primary subgraph.
  AddSubgraphs(1);
  context_ = primary_subgraph().context();
  // This operation is cheap because we allocate the CPU context resources (i.e.
  // threads) lazily.
  own_external_cpu_backend_context_.reset(new ExternalCpuBackendContext());
  external_contexts_[kTfLiteCpuBackendContext] =
      own_external_cpu_backend_context_.get();
  UseNNAPI(false);
}
// Nothing is released explicitly here. Cleanup is left to the members' own
// destructors.
Interpreter::~Interpreter() = default;
// Installs `ctx` in the external-context slot `type` (via the primary
// subgraph, which writes the shared external_contexts_ array).
//
// Passing the internally owned CPU backend context back in is refused with a
// warning. Replacing the kTfLiteCpuBackendContext slot while it still holds
// the internally owned context releases that owned context.
void Interpreter::SetExternalContext(TfLiteExternalContextType type,
                                     TfLiteExternalContext* ctx) {
  // The nullptr guard matters. Once own_external_cpu_backend_context_ has been
  // released (below), .get() returns nullptr. Without the guard, every later
  // call that clears a slot (ctx == nullptr) would be wrongly rejected here.
  if (ctx != nullptr && ctx == own_external_cpu_backend_context_.get()) {
    error_reporter_->Report(
        "WARNING: The passed external context is identical to the internally "
        "owned one.");
    return;
  }
  // We have an internally owned external context of kTfLiteCpuBackendContext.
  // If it's overwritten here, we will release the resource of the internally
  // owned external context.
  // Note: the 'max thread count' info associated with the overwritten context
  // will be lost here, and such info is now determined by the new context,
  // thus affecting how much parallelism a TFLite op would have.
  if (kTfLiteCpuBackendContext == type &&
      external_contexts_[kTfLiteCpuBackendContext] ==
          own_external_cpu_backend_context_.get()) {
    own_external_cpu_backend_context_.reset();
  }
  // This essentially changes the "external_contexts_[type]".
  primary_subgraph().SetExternalContext(type, ctx);
}
// Sets the primary subgraph's input tensor indices.
// `inputs` is already a by-value copy owned by this call. It is moved onward
// so that a by-value callee parameter does not trigger a second copy.
TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  return primary_subgraph().SetInputs(std::move(inputs));
}
// Sets the primary subgraph's output tensor indices.
// `outputs` is already a by-value copy owned by this call. It is moved onward
// so that a by-value callee parameter does not trigger a second copy.
TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  return primary_subgraph().SetOutputs(std::move(outputs));
}
// Sets the primary subgraph's variable tensor indices.
// `variables` is already a by-value copy owned by this call. It is moved
// onward so that a by-value callee parameter does not trigger a second copy.
TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
  return primary_subgraph().SetVariables(std::move(variables));
}
// Delegates tensor allocation to the primary subgraph and returns its status.
TfLiteStatus Interpreter::AllocateTensors() {
  auto& primary = primary_subgraph();
  const TfLiteStatus status = primary.AllocateTensors();
  return status;
}
// Forwards the node-reservation request for `count` nodes to the primary
// subgraph.
void Interpreter::ReserveNodes(int count) {
  auto& primary = primary_subgraph();
  primary.ReserveNodes(count);
}
// Appends `subgraphs_to_add` new subgraphs to subgraphs_.
//
// Every new subgraph shares this interpreter's error reporter, the
// external-context array, the subgraph list itself and the resource
// variables. If `first_new_subgraph_index` is non-null, it receives the index
// the first new subgraph occupies (the list size before the call). A
// non-positive count adds nothing.
void Interpreter::AddSubgraphs(int subgraphs_to_add,
                               int* first_new_subgraph_index) {
  const size_t base_index = subgraphs_.size();
  if (first_new_subgraph_index) {
    *first_new_subgraph_index = static_cast<int>(base_index);
  }
  // Bail out before reserve(). A negative int added to a size_t wraps around,
  // and reserve() could then be asked for an absurd capacity
  // (std::length_error, or an abort in builds without exceptions).
  if (subgraphs_to_add <= 0) return;
  subgraphs_.reserve(base_index + static_cast<size_t>(subgraphs_to_add));
  for (int i = 0; i < subgraphs_to_add; ++i) {
    Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_,
                                      &subgraphs_, &resource_variables_);
    subgraphs_.emplace_back(subgraph);
  }
}
// Adds a node to the primary subgraph, forwarding every argument unchanged.
// The third argument of the subgraph call is passed empty ({}). It is
// presumably the node's intermediate-tensor list; TODO confirm against
// Subgraph::AddNodeWithParameters.
TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  auto& primary = primary_subgraph();
  return primary.AddNodeWithParameters(inputs, outputs, {}, init_data,
                                       init_data_size, builtin_data,
                                       registration, node_index);
}
// Asks the primary subgraph to resize input tensor `tensor_index` to `dims`.
TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  auto& primary = primary_subgraph();
  const TfLiteStatus status = primary.ResizeInputTensor(tensor_index, dims);
  return status;
}
// Runs the primary subgraph.
//
// Unless allow_buffer_handle_output_ is set, every output tensor is then
// passed through EnsureTensorDataIsReadable. That presumably materializes
// delegate-held output data for the caller; confirm in Subgraph. The first
// failing status is returned.
TfLiteStatus Interpreter::Invoke() {
  auto& primary = primary_subgraph();
  TF_LITE_ENSURE_STATUS(primary.Invoke());
  if (allow_buffer_handle_output_) {
    return kTfLiteOk;
  }
  for (const int output_index : outputs()) {
    TF_LITE_ENSURE_STATUS(primary.EnsureTensorDataIsReadable(output_index));
  }
  return kTfLiteOk;
}
// Adds `tensors_to_add` tensors to the primary subgraph. The index of the
// first one is reported through `first_new_tensor_index`, exactly as the
// subgraph does.
TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  auto& primary = primary_subgraph();
  const TfLiteStatus status =
      primary.AddTensors(tensors_to_add, first_new_tensor_index);
  return status;
}
// Forwards the variable-tensor reset to the primary subgraph.
TfLiteStatus Interpreter::ResetVariableTensors() {
  auto& primary = primary_subgraph();
  const TfLiteStatus status = primary.ResetVariableTensors();
  return status;
}
// Vector-dims convenience overload. The shape is unpacked into
// (rank, pointer) form and everything else goes to the primary subgraph
// untouched.
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    const char* buffer, size_t bytes, const Allocation* allocation) {
  const size_t rank = dims.size();
  const int* const dims_data = dims.data();
  auto& primary = primary_subgraph();
  return primary.SetTensorParametersReadOnly(tensor_index, type, name, rank,
                                             dims_data, quantization, buffer,
                                             bytes, allocation);
}
// Vector-dims convenience overload. The shape is unpacked into
// (rank, pointer) form and everything else goes to the primary subgraph
// untouched.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    bool is_variable) {
  const size_t rank = dims.size();
  const int* const dims_data = dims.data();
  auto& primary = primary_subgraph();
  return primary.SetTensorParametersReadWrite(tensor_index, type, name, rank,
                                              dims_data, quantization,
                                              is_variable);
}
// Legacy-quantization overload. The single scale/zero-point pair is converted
// into an affine TfLiteQuantization first, then the primary subgraph does the
// work.
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
    size_t bytes, const Allocation* allocation) {
  const TfLiteQuantization affine = GetQuantizationFromLegacy(quantization);
  auto& primary = primary_subgraph();
  return primary.SetTensorParametersReadOnly(tensor_index, type, name, rank,
                                             dims, affine, buffer, bytes,
                                             allocation);
}
// Legacy-quantization overload. The single scale/zero-point pair is converted
// into an affine TfLiteQuantization first, then the primary subgraph does the
// work.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
  const TfLiteQuantization affine = GetQuantizationFromLegacy(quantization);
  auto& primary = primary_subgraph();
  return primary.SetTensorParametersReadWrite(tensor_index, type, name, rank,
                                              dims, affine, is_variable);
}
// Replaces the primary subgraph's execution plan with `new_plan`.
TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  auto& primary = primary_subgraph();
  const TfLiteStatus status = primary.SetExecutionPlan(new_plan);
  return status;
}
// Toggles NNAPI usage on the primary subgraph only.
void Interpreter::UseNNAPI(bool enable) {
  auto& primary = primary_subgraph();
  primary.UseNNAPI(enable);
}
void Interpreter::SetNumThreads(int num_threads) {
for (auto& subgraph : subgraphs_) {
subgraph->context()->recommended_num_threads = num_threads;
}
for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
auto* c = external_contexts_[i];
if (c && c->Refresh) {
c->Refresh(context_);
}
}
}
void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {
for (auto& subgraph : subgraphs_) {
subgraph->context()->allow_fp32_relax_to_fp16 = allow;
}
}
// TODO(b/121264966): Subgraphs added after cancellation is set will not get the
// cancellation function added to their context.
void Interpreter::SetCancellationFunction(void* data,
bool (*check_cancelled_func)(void*)) {
for (auto& subgraph : subgraphs_) {
subgraph->SetCancellationFunction(data, check_cancelled_func);
}
}
// Applies `delegate` to every subgraph in order. The first failure is
// reported through context_ and returned as an error. Subgraphs processed
// before the failure are not rolled back.
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  for (size_t idx = 0; idx < subgraphs_.size(); ++idx) {
    TF_LITE_ENSURE_OK(context_,
                      subgraphs_[idx]->ModifyGraphWithDelegate(delegate));
  }
  return kTfLiteOk;
}
// Owning overload. The interpreter keeps `delegate` alive in owned_delegates_
// and then applies it through the non-owning overload.
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegatePtr delegate) {
  TfLiteDelegate* const raw_delegate = delegate.get();
  // Ownership is retained even if graph modification fails, because delegate
  // use is in an indeterminate state at that point.
  owned_delegates_.push_back(std::move(delegate));
  return ModifyGraphWithDelegate(raw_delegate);
}
// Associates `buffer_handle` (owned by `delegate`) with tensor `tensor_index`
// in the primary subgraph.
//
// The tensor must be unbound or already bound to this same delegate. Any
// previously set handle is released through the delegate's FreeBufferHandle
// before the new one is stored. On error the tensor is left unchanged.
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle buffer_handle,
                                          TfLiteDelegate* delegate) {
  // Check the sign explicitly instead of relying on a signed/unsigned
  // conversion inside the comparison.
  TF_LITE_ENSURE(context_, tensor_index >= 0 &&
                               static_cast<size_t>(tensor_index) <
                                   tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];
  TF_LITE_ENSURE(context_,
                 tensor->delegate == nullptr || tensor->delegate == delegate);
  if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
    // The old handle is freed with `delegate`; the check above makes that the
    // tensor's delegate or the one being adopted. A handle previously set
    // with a null delegate has nothing to free it with. Reject that case
    // instead of dereferencing nullptr, and do it before mutating the tensor.
    TF_LITE_ENSURE(context_, delegate != nullptr &&
                                 delegate->FreeBufferHandle != nullptr);
    delegate->FreeBufferHandle(context_, delegate, &tensor->buffer_handle);
  }
  tensor->delegate = delegate;
  tensor->buffer_handle = buffer_handle;
  return kTfLiteOk;
}
// Reports the buffer handle and owning delegate currently attached to tensor
// `tensor_index` in the primary subgraph.
//
// Returns an error for an out-of-range index or a null output pointer;
// neither output is written in that case.
TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle* buffer_handle,
                                          TfLiteDelegate** delegate) {
  // Check the sign explicitly instead of relying on a signed/unsigned
  // conversion inside the comparison.
  TF_LITE_ENSURE(context_, tensor_index >= 0 &&
                               static_cast<size_t>(tensor_index) <
                                   tensors_size());
  TF_LITE_ENSURE(context_, buffer_handle != nullptr && delegate != nullptr);
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];
  *delegate = tensor->delegate;
  *buffer_handle = tensor->buffer_handle;
  return kTfLiteOk;
}
// Installs `profiler` on every subgraph currently owned by the interpreter.
void Interpreter::SetProfiler(Profiler* profiler) {
  for (size_t idx = 0; idx < subgraphs_.size(); ++idx) {
    subgraphs_[idx]->SetProfiler(profiler);
  }
}
// Returns whatever profiler the primary subgraph currently reports.
Profiler* Interpreter::GetProfiler() {
  auto& primary = primary_subgraph();
  Profiler* const current = primary.GetProfiler();
  return current;
}
} // namespace tflite