| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/lite/c/c_api.h" |
| |
| #include <memory> |
| |
| #include "tensorflow/lite/builtin_ops.h" |
| #include "tensorflow/lite/c/c_api_internal.h" |
| #include "tensorflow/lite/delegates/interpreter_utils.h" |
| #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" |
| #include "tensorflow/lite/error_reporter.h" |
| #include "tensorflow/lite/interpreter.h" |
| #include "tensorflow/lite/kernels/internal/compatibility.h" |
| #include "tensorflow/lite/kernels/register.h" |
| #include "tensorflow/lite/model.h" |
| #include "tensorflow/lite/version.h" |
| |
| namespace { |
| class CallbackErrorReporter : public tflite::ErrorReporter { |
| public: |
| explicit CallbackErrorReporter(TfLiteErrorReporterCallback callback) |
| : callback_(callback) {} |
| |
| int Report(const char* format, va_list args) override { |
| callback_.error_reporter(callback_.user_data, format, args); |
| return 0; |
| } |
| |
| private: |
| TfLiteErrorReporterCallback callback_; |
| }; |
| |
| /// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the |
| /// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks` |
| /// struct. |
| /// |
| /// The SetCallbacks method must be called before calling any of the FindOp |
| /// methods. |
| class CallbackOpResolver : public ::tflite::OpResolver { |
| public: |
| CallbackOpResolver() {} |
| void SetCallbacks( |
| const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) { |
| op_resolver_callbacks_ = op_resolver_callbacks; |
| } |
| const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, |
| int version) const override { |
| if (op_resolver_callbacks_.find_builtin_op == nullptr) { |
| return nullptr; |
| } |
| return op_resolver_callbacks_.find_builtin_op( |
| op_resolver_callbacks_.user_data, |
| static_cast<TfLiteBuiltinOperator>(op), version); |
| } |
| const TfLiteRegistration* FindOp(const char* op, int version) const override { |
| if (op_resolver_callbacks_.find_custom_op == nullptr) { |
| return nullptr; |
| } |
| return op_resolver_callbacks_.find_custom_op( |
| op_resolver_callbacks_.user_data, op, version); |
| } |
| |
| private: |
| CallbackOpResolver(const CallbackOpResolver&) = delete; |
| CallbackOpResolver& operator=(const CallbackOpResolver&) = delete; |
| |
| struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {}; |
| }; |
| |
| } // namespace |
| |
| extern "C" { |
| |
| // LINT.IfChange |
| |
| const char* TfLiteVersion() { return TFLITE_VERSION_STRING; } |
| |
| TfLiteModel* TfLiteModelCreate(const void* model_data, size_t model_size) { |
| auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( |
| static_cast<const char*>(model_data), model_size); |
| std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release()); |
| return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr; |
| } |
| |
| TfLiteModel* TfLiteModelCreateFromFile(const char* model_path) { |
| auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(model_path); |
| std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release()); |
| return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr; |
| } |
| |
| void TfLiteModelDelete(TfLiteModel* model) { delete model; } |
| |
| TfLiteInterpreterOptions* TfLiteInterpreterOptionsCreate() { |
| return new TfLiteInterpreterOptions{}; |
| } |
| |
| void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions* options) { |
| delete options; |
| } |
| |
| void TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions* options, |
| int32_t num_threads) { |
| options->num_threads = num_threads; |
| } |
| |
| void TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions* options, |
| TfLiteDelegate* delegate) { |
| options->delegates.push_back(delegate); |
| } |
| |
| void TfLiteInterpreterOptionsSetErrorReporter( |
| TfLiteInterpreterOptions* options, |
| void (*reporter)(void* user_data, const char* format, va_list args), |
| void* user_data) { |
| options->error_reporter_callback.error_reporter = reporter; |
| options->error_reporter_callback.user_data = user_data; |
| } |
| |
| TfLiteInterpreter* TfLiteInterpreterCreate( |
| const TfLiteModel* model, |
| const TfLiteInterpreterOptions* optional_options) { |
| tflite::ops::builtin::BuiltinOpResolver resolver; |
| return tflite::internal::InterpreterCreateWithOpResolver( |
| model, optional_options, &resolver); |
| } |
| |
| void TfLiteInterpreterDelete(TfLiteInterpreter* interpreter) { |
| delete interpreter; |
| } |
| |
| int32_t TfLiteInterpreterGetInputTensorCount( |
| const TfLiteInterpreter* interpreter) { |
| return static_cast<int32_t>(interpreter->impl->inputs().size()); |
| } |
| |
| TfLiteTensor* TfLiteInterpreterGetInputTensor( |
| const TfLiteInterpreter* interpreter, int32_t input_index) { |
| return interpreter->impl->tensor(interpreter->impl->inputs()[input_index]); |
| } |
| |
| TfLiteStatus TfLiteInterpreterResizeInputTensor(TfLiteInterpreter* interpreter, |
| int32_t input_index, |
| const int* input_dims, |
| int32_t input_dims_size) { |
| std::vector<int> dims{input_dims, input_dims + input_dims_size}; |
| return interpreter->impl->ResizeInputTensor( |
| interpreter->impl->inputs()[input_index], dims); |
| } |
| |
| TfLiteStatus TfLiteInterpreterAllocateTensors(TfLiteInterpreter* interpreter) { |
| return interpreter->impl->AllocateTensors(); |
| } |
| |
| TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) { |
| if (interpreter->enable_delegate_fallback) { |
| return tflite::delegates::InterpreterUtils::InvokeWithCPUFallback( |
| interpreter->impl.get()); |
| } else { |
| return interpreter->impl->Invoke(); |
| } |
| } |
| |
| int32_t TfLiteInterpreterGetOutputTensorCount( |
| const TfLiteInterpreter* interpreter) { |
| return static_cast<int32_t>(interpreter->impl->outputs().size()); |
| } |
| |
| const TfLiteTensor* TfLiteInterpreterGetOutputTensor( |
| const TfLiteInterpreter* interpreter, int32_t output_index) { |
| return interpreter->impl->tensor(interpreter->impl->outputs()[output_index]); |
| } |
| |
| TfLiteType TfLiteTensorType(const TfLiteTensor* tensor) { return tensor->type; } |
| |
| int32_t TfLiteTensorNumDims(const TfLiteTensor* tensor) { |
| return tensor->dims->size; |
| } |
| |
| int32_t TfLiteTensorDim(const TfLiteTensor* tensor, int32_t dim_index) { |
| return tensor->dims->data[dim_index]; |
| } |
| |
| size_t TfLiteTensorByteSize(const TfLiteTensor* tensor) { |
| return tensor->bytes; |
| } |
| |
| void* TfLiteTensorData(const TfLiteTensor* tensor) { |
| return static_cast<void*>(tensor->data.raw); |
| } |
| |
| const char* TfLiteTensorName(const TfLiteTensor* tensor) { |
| return tensor->name; |
| } |
| |
| TfLiteQuantizationParams TfLiteTensorQuantizationParams( |
| const TfLiteTensor* tensor) { |
| return tensor->params; |
| } |
| |
| TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor* tensor, |
| const void* input_data, |
| size_t input_data_size) { |
| if (tensor->bytes != input_data_size) { |
| return kTfLiteError; |
| } |
| memcpy(tensor->data.raw, input_data, input_data_size); |
| return kTfLiteOk; |
| } |
| |
| TfLiteStatus TfLiteTensorCopyToBuffer(const TfLiteTensor* tensor, |
| void* output_data, |
| size_t output_data_size) { |
| if (tensor->bytes != output_data_size) { |
| return kTfLiteError; |
| } |
| memcpy(output_data, tensor->data.raw, output_data_size); |
| return kTfLiteOk; |
| } |
| |
| // LINT.ThenChange(//tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs) |
| |
| } // extern "C" |
| |
| namespace tflite { |
| namespace internal { |
| |
| TfLiteInterpreter* InterpreterCreateWithOpResolver( |
| const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options, |
| tflite::MutableOpResolver* mutable_resolver) { |
| TFLITE_DCHECK_NE(mutable_resolver, nullptr); |
| if (!model || !model->impl) { |
| return nullptr; |
| } |
| |
| std::unique_ptr<tflite::ErrorReporter> optional_error_reporter; |
| if (optional_options && |
| optional_options->error_reporter_callback.error_reporter != nullptr) { |
| optional_error_reporter.reset( |
| new CallbackErrorReporter(optional_options->error_reporter_callback)); |
| } |
| |
| // By default, we use the provided mutable_op_resolver, adding any builtin or |
| // custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or |
| // `TfLiteInterpreterOptionsAddCustomOp`. |
| tflite::OpResolver* op_resolver = mutable_resolver; |
| if (optional_options) { |
| mutable_resolver->AddAll(optional_options->mutable_op_resolver); |
| } |
| // However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with |
| // a non-null callback parameter, then we instead use a |
| // `CallbackOpResolver` that will forward to the callbacks provided there. |
| CallbackOpResolver callback_op_resolver; |
| if (optional_options && |
| (optional_options->op_resolver_callbacks.find_builtin_op != nullptr || |
| optional_options->op_resolver_callbacks.find_custom_op != nullptr)) { |
| callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks); |
| op_resolver = &callback_op_resolver; |
| } |
| |
| tflite::ErrorReporter* error_reporter = optional_error_reporter |
| ? optional_error_reporter.get() |
| : tflite::DefaultErrorReporter(); |
| tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver, |
| error_reporter); |
| |
| std::unique_ptr<tflite::Interpreter> interpreter; |
| if (builder(&interpreter) != kTfLiteOk) { |
| return nullptr; |
| } |
| |
| if (optional_options) { |
| if (optional_options->num_threads != |
| TfLiteInterpreterOptions::kDefaultNumThreads) { |
| interpreter->SetNumThreads(optional_options->num_threads); |
| } |
| |
| if (optional_options->use_nnapi) { |
| if (interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate()) != |
| kTfLiteOk) { |
| return nullptr; |
| } |
| } |
| |
| for (auto* delegate : optional_options->delegates) { |
| if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) { |
| return nullptr; |
| } |
| } |
| } |
| |
| bool enable_delegate_fallback = |
| optional_options != nullptr && optional_options->enable_delegate_fallback; |
| |
| return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter), |
| std::move(interpreter), |
| enable_delegate_fallback}; |
| } |
| |
| } // namespace internal |
| } // namespace tflite |