/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/model.h"
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"
namespace tflite {
namespace {
// Ensure that ErrorReporter is non-null.
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
return e ? e : DefaultErrorReporter();
}
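// Placeholder invoke function for custom ops that have no registration yet
// (referenced by InterpreterBuilder::BuildLocalIndexToRegistrationMapping
// below). No definition appears elsewhere in this file, so a minimal stub is
// assumed here: it simply reports an error if such an op is ever executed
// without having been claimed by a delegate.
TfLiteStatus UnresolvedOpInvoke(TfLiteContext* context, TfLiteNode* node) {
  context->ReportError(context,
                       "Encountered an unresolved custom op. Did you miss "
                       "a custom op or a delegate?");
  return kTfLiteError;
}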
} // namespace
const char* kEmptyTensorName = "";
// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but
// we avoid the absl dependency for binary size reasons.
#ifdef __has_attribute
#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x)
#else
#define TFLITE_HAS_ATTRIBUTE(x) 0
#endif
#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
// Using weak symbols for the flex delegate allows automatic injection of the
// delegate simply by adding it as a dependency. See also the strong override in
// lite/delegates/flex/delegate.cc.
__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}
#else
Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
#endif
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true, the file is mmapped
// when supported; otherwise the model is copied into a buffer. (The
// `use_nnapi` argument is currently accepted but unused.)
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
bool mmap_file,
ErrorReporter* error_reporter,
bool use_nnapi) {
std::unique_ptr<Allocation> allocation;
if (mmap_file && MMAPAllocation::IsSupported()) {
allocation.reset(new MMAPAllocation(filename, error_reporter));
} else {
allocation.reset(new FileCopyAllocation(filename, error_reporter));
}
return allocation;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
const char* filename, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
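// Example usage (a minimal sketch; the file name is illustrative and
// BuiltinOpResolver comes from tensorflow/lite/kernels/register.h):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   if (!model) { /* handle load failure */ }
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);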
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
flatbuffers::Verifier base_verifier(
reinterpret_cast<const uint8_t*>(allocation->base()),
allocation->bytes());
if (!VerifyModelBuffer(base_verifier)) {
error_reporter->Report("The model is not a valid Flatbuffer file");
return nullptr;
}
if (extra_verifier &&
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
    return nullptr;
}
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
#endif  // !defined(TFLITE_MCU)
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
const char* buffer, size_t buffer_size, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t*>(buffer),
buffer_size);
if (!VerifyModelBuffer(base_verifier)) {
error_reporter->Report("The model is not a valid Flatbuffer buffer");
return nullptr;
}
if (extra_verifier &&
!extra_verifier->Verify(buffer, buffer_size, error_reporter)) {
return nullptr;
}
return BuildFromBuffer(buffer, buffer_size, error_reporter);
}
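// Example of an `extra_verifier` (a sketch only; the 10MB limit is
// illustrative, TfLiteVerifier is declared in tensorflow/lite/model.h):
//
//   class SizeLimitVerifier : public tflite::TfLiteVerifier {
//    public:
//     bool Verify(const char* data, int length,
//                 tflite::ErrorReporter* reporter) override {
//       if (length > 10 * 1024 * 1024) {
//         reporter->Report("Model exceeds the 10MB size limit.");
//         return false;
//       }
//       return true;
//     }
//   };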
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
const tflite::Model* caller_owned_model_spec,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
if (!model->initialized()) model.reset();
return model;
}
string FlatBufferModel::GetMinimumRuntime() const {
if (!model_ || !model_->metadata()) return "";
for (int i = 0; i < model_->metadata()->size(); ++i) {
auto metadata = model_->metadata()->Get(i);
if (metadata->name()->str() == "min_runtime_version") {
auto buf = metadata->buffer();
auto* buffer = (*model_->buffers())[buf];
auto* array = buffer->data();
return string(reinterpret_cast<const char*>(array->data()),
array->size());
}
}
return "";
}
bool FlatBufferModel::CheckModelIdentifier() const {
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
error_reporter_->Report(
"Model provided has model identifier '%c%c%c%c', should be '%s'\n",
ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
return false;
}
return true;
}
FlatBufferModel::FlatBufferModel(const Model* model,
ErrorReporter* error_reporter)
: model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
ErrorReporter* error_reporter)
: error_reporter_(ValidateErrorReporter(error_reporter)),
allocation_(std::move(allocation)) {
if (!allocation_->valid() || !CheckModelIdentifier()) return;
model_ = ::tflite::GetModel(allocation_->base());
}
FlatBufferModel::~FlatBufferModel() {}
InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
const OpResolver& op_resolver)
: model_(model.GetModel()),
op_resolver_(op_resolver),
error_reporter_(ValidateErrorReporter(model.error_reporter())),
allocation_(model.allocation()) {}
InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
const OpResolver& op_resolver,
ErrorReporter* error_reporter)
: model_(model),
op_resolver_(op_resolver),
error_reporter_(ValidateErrorReporter(error_reporter)) {}
InterpreterBuilder::~InterpreterBuilder() {}
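// Builds the table that maps each entry of the model's operator_codes() to a
// TfLiteRegistration obtained from the OpResolver. Custom ops that cannot be
// resolved yet receive a placeholder registration so that a delegate (e.g.
// the flex delegate) may claim them later; any other unresolved op is an
// error.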
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
TfLiteStatus status = kTfLiteOk;
// Reset state.
flatbuffer_op_index_to_registration_.clear();
unresolved_custom_ops_.clear();
auto opcodes = model_->operator_codes();
if (!opcodes) {
return status;
}
int num_custom_ops = 0;
for (const OperatorCode* opcode : *opcodes) {
if (opcode->builtin_code() == BuiltinOperator_CUSTOM) {
num_custom_ops++;
}
}
unresolved_custom_ops_.reserve(num_custom_ops);
for (const OperatorCode* opcode : *opcodes) {
const TfLiteRegistration* registration = nullptr;
status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
&registration);
if (status != kTfLiteOk) {
if (opcode->builtin_code() != BuiltinOperator_CUSTOM) {
return status;
}
// If it's an unresolved custom op, allow it for now. It might be resolved
// by a delegate later.
if (!opcode->custom_code()) {
error_reporter_->Report(
"Operator with CUSTOM builtin_code has no custom_code.\n");
return status;
}
const auto* op_name = opcode->custom_code()->c_str();
TfLiteRegistration unresolved_op{nullptr,
nullptr,
nullptr,
/*invoke*/ &UnresolvedOpInvoke,
nullptr,
BuiltinOperator_CUSTOM,
op_name,
1};
unresolved_custom_ops_.push_back(unresolved_op);
registration = &unresolved_custom_ops_.back();
has_flex_op_ |= IsFlexOp(op_name);
status = kTfLiteOk;
}
flatbuffer_op_index_to_registration_.push_back(registration);
}
return status;
}
namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
// Converts a flatbuffer int array into a std::vector<int>. A null
// `flat_array` yields an empty vector: models constructed via
// flatbuffers::Pack store empty vectors (e.g. a scalar tensor's shape) as
// nullptr.
if (flat_array == nullptr) {
return {};
}
std::vector<int> ret(flat_array->Length());
for (int i = 0; i < flat_array->Length(); i++) {
ret[i] = flat_array->Get(i);
}
return ret;
}
// Used to determine how the op data parsing function creates its working space.
class MallocDataAllocator : public BuiltinDataAllocator {
public:
void* Allocate(size_t size) override { return malloc(size); }
void Deallocate(void* data) override { free(data); }
};
} // namespace
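// Adds one node to `subgraph` for every operator in `operators`. Custom ops
// keep their raw custom_options bytes; builtin ops have their options parsed
// into a builtin_data struct via ParseOpData().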
TfLiteStatus InterpreterBuilder::ParseNodes(
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
Subgraph* subgraph) {
TfLiteStatus status = kTfLiteOk;
// Reduce the number of redundant allocations
subgraph->ReserveNodes(operators->Length());
for (int i = 0; i < operators->Length(); ++i) {
const auto* op = operators->Get(i);
int index = op->opcode_index();
if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
error_reporter_->Report("Missing registration for opcode_index %d\n",
index);
status = kTfLiteError;
continue;
}
const TfLiteRegistration* registration =
flatbuffer_op_index_to_registration_[index];
if (registration == nullptr) {
error_reporter_->Report("Skipping op for opcode_index %d\n", index);
status = kTfLiteError;
continue;
}
BuiltinOperator op_type =
static_cast<BuiltinOperator>(registration->builtin_code);
if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
error_reporter_->Report(
"Found builtin operator %s with custom options.\n",
EnumNameBuiltinOperator(op_type));
}
if (op_type == BuiltinOperator_CUSTOM) {
if (op->custom_options()) {
subgraph->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()),
FlatBufferIntArrayToVector(op->intermediates()),
reinterpret_cast<const char*>(op->custom_options()->data()),
op->custom_options()->size(), nullptr, registration);
} else {
subgraph->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()),
FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
nullptr, registration);
}
} else {
void* builtin_data = nullptr;
MallocDataAllocator malloc_allocator;
TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
&malloc_allocator, &builtin_data));
subgraph->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()),
FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
builtin_data, registration);
}
}
return status;
}
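// Converts the flatbuffer QuantizationParameters of a tensor into a
// TfLiteQuantization. Both per-layer quantization (a single scale/zero_point
// pair) and per-axis quantization are handled; for example, a per-axis
// quantized filter of shape [64, 3, 3, 3] with quantized_dimension 0 carries
// 64 scale/zero_point pairs.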
TfLiteStatus InterpreterBuilder::ParseQuantization(
const QuantizationParameters* src_quantization,
TfLiteQuantization* quantization, const std::vector<int>& dims) {
quantization->type = kTfLiteNoQuantization;
if (!src_quantization || !src_quantization->scale() ||
src_quantization->scale()->size() == 0) {
return kTfLiteOk;
}
if (!src_quantization->zero_point()) {
error_reporter_->Report(
"Quantization parameters has non-null scale but null zero_point.");
return kTfLiteError;
}
// Ensure that the number of scales matches the number of zero_points.
if (src_quantization->scale()->size() !=
src_quantization->zero_point()->size()) {
error_reporter_->Report(
"QuantizationParam has %d zero_point values and %d scale values. Must "
"have same number.",
src_quantization->zero_point()->size(),
src_quantization->scale()->size());
return kTfLiteError;
}
// Affine-quantization.
quantization->type = kTfLiteAffineQuantization;
const size_t num_scales = src_quantization->scale()->size();
// Ensure that the quantization dimension is valid.
if (src_quantization->quantized_dimension() < 0 ||
(!dims.empty() &&
src_quantization->quantized_dimension() >= dims.size())) {
error_reporter_->Report(
"quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
src_quantization->quantized_dimension());
return kTfLiteError;
}
  // Ensure that the number of scales is 1 for per-layer quantization, or that
  // it matches the size of the quantized dimension for per-axis quantization.
if (num_scales != 1 &&
(!dims.empty() &&
num_scales != dims[src_quantization->quantized_dimension()])) {
error_reporter_->Report(
"num_scales must be 1 for per-layer quantization, or %d for per-axis "
"quantization, but got %d.",
dims[src_quantization->quantized_dimension()], num_scales);
return kTfLiteError;
}
auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
malloc(sizeof(TfLiteAffineQuantization)));
affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales);
for (size_t i = 0; i < num_scales; ++i) {
affine_quantization->scale->data[i] = src_quantization->scale()->Get(i);
affine_quantization->zero_point->data[i] =
src_quantization->zero_point()->Get(i);
}
affine_quantization->quantized_dimension =
src_quantization->quantized_dimension();
quantization->params = reinterpret_cast<void*>(affine_quantization);
return kTfLiteOk;
}
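// Registers every tensor of the flatbuffer subgraph with `subgraph`. Tensors
// backed by a non-empty flatbuffer buffer become read-only tensors pointing
// into the model allocation; all others become read-write tensors.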
TfLiteStatus InterpreterBuilder::ParseTensors(
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
Subgraph* subgraph) {
TfLiteStatus status = kTfLiteOk;
  // A small helper to get a tensor's name. The returned pointer points into
  // the flatbuffer, which must outlive the subgraph.
auto get_name = [](const tflite::Tensor* t) -> const char* {
auto name = t->name();
if (name) return name->c_str();
return kEmptyTensorName;
};
for (int i = 0; i < tensors->Length(); ++i) {
const auto* tensor = tensors->Get(i);
std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
TfLiteType type;
if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
kTfLiteOk) {
status = kTfLiteError;
continue;
}
auto get_readonly_data = [&](const char** buffer_data,
size_t* buffer_size) {
// TODO(aselle): Check what happens if we have an unspecified size
// constant.
*buffer_data = nullptr;
if (tensor->buffer() == 0) return kTfLiteOk;
if (tensor->buffer() >= buffers->size()) {
error_reporter_->Report(
"Tensor %d specifies out of range buffer %d (only %d buffers).\n",
i, tensor->buffer(), buffers->size());
return kTfLiteError;
}
if (auto* buffer = (*buffers)[tensor->buffer()]) {
if (auto* array = buffer->data()) {
if (size_t size = array->size()) {
*buffer_size = size;
*buffer_data = reinterpret_cast<const char*>(array->data());
return kTfLiteOk;
}
}
}
return kTfLiteOk;
};
size_t buffer_size = 0;
const char* buffer_ptr;
TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
const auto* src_quantization = tensor->quantization();
TfLiteQuantization quantization;
if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
status = kTfLiteError;
continue;
}
bool is_variable = tensor->is_variable();
if (buffer_ptr) {
if (is_variable) {
        error_reporter_->Report(
            "Tensor %d is a variable tensor with a buffer; this is not yet "
            "supported.\n",
            i);
status = kTfLiteError;
}
if (subgraph->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;
}
} else {
if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor),
dims, quantization,
is_variable) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;
}
}
}
return status;
}
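// Applies the flex delegate to the interpreter when the model contains flex
// (TensorFlow select) ops and a flex delegate implementation has been linked
// in via AcquireFlexDelegate().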
TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) {
// Apply Flex delegate if applicable.
if (!has_flex_op_ || AcquireFlexDelegate == nullptr) {
return kTfLiteOk;
} else if (auto flex_delegate = AcquireFlexDelegate()) {
return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate));
}
return kTfLiteOk;
}
TfLiteStatus InterpreterBuilder::operator()(
std::unique_ptr<Interpreter>* interpreter) {
return operator()(interpreter, /*num_threads=*/-1);
}
TfLiteStatus InterpreterBuilder::operator()(
std::unique_ptr<Interpreter>* interpreter, int num_threads) {
if (!interpreter) {
error_reporter_->Report(
"Null output pointer passed to InterpreterBuilder.");
return kTfLiteError;
}
  // Safe exit by deleting a partially created interpreter, to reduce verbosity
  // on error conditions. Usage: return cleanup_and_error();
auto cleanup_and_error = [&interpreter]() {
interpreter->reset();
return kTfLiteError;
};
if (!model_) {
error_reporter_->Report("Null pointer passed in as model.");
return cleanup_and_error();
}
if (model_->version() != TFLITE_SCHEMA_VERSION) {
error_reporter_->Report(
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model_->version(), TFLITE_SCHEMA_VERSION);
return cleanup_and_error();
}
  // Flatbuffer model schemas define a list of opcodes independent of the
  // graph. We first map those to registrations. This reduces string lookups
  // for custom ops since we only do it once per custom op rather than once
  // per custom op invocation in the model graph.
  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Construct the interpreter with the correct number of subgraphs, tensors
  // and operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();
  if (!subgraphs || subgraphs->size() == 0) {
error_reporter_->Report("No subgraph in the model.\n");
return cleanup_and_error();
}
interpreter->reset(new Interpreter(error_reporter_));
(*interpreter)->SetNumThreads(num_threads);
if (subgraphs->Length() > 1) {
(*interpreter)->AddSubgraphs(subgraphs->Length() - 1);
}
for (int subgraph_index = 0; subgraph_index < subgraphs->Length();
++subgraph_index) {
const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
tflite::Subgraph* modified_subgraph =
(*interpreter)->subgraph(subgraph_index);
auto operators = subgraph->operators();
auto tensors = subgraph->tensors();
if (!operators || !tensors || !buffers) {
error_reporter_->Report(
"Did not get operators, tensors, or buffers in subgraph %d.\n",
subgraph_index);
return cleanup_and_error();
}
if (modified_subgraph->AddTensors(tensors->Length()) != kTfLiteOk) {
return cleanup_and_error();
}
    // Parse inputs/outputs.
modified_subgraph->SetInputs(
FlatBufferIntArrayToVector(subgraph->inputs()));
modified_subgraph->SetOutputs(
FlatBufferIntArrayToVector(subgraph->outputs()));
    // Finally, set up nodes and tensors.
if (ParseNodes(operators, modified_subgraph) != kTfLiteOk)
return cleanup_and_error();
if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
return cleanup_and_error();
std::vector<int> variables;
for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
auto* tensor = modified_subgraph->tensor(i);
if (tensor->is_variable) {
variables.push_back(i);
}
}
modified_subgraph->SetVariables(std::move(variables));
}
if (ApplyDelegates(interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
return kTfLiteOk;
}
} // namespace tflite