| /** |
| * Copyright 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
#include "run_tflite.h"

#include "tensorflow/contrib/lite/kernels/register.h"

#include <android/log.h>
#include <cstdio>
#include <cstring>
#include <sys/time.h>
| |
| #define LOG_TAG "NN_BENCHMARK" |
| |
| BenchmarkModel::BenchmarkModel(const char* modelfile) { |
| // Memory map the model. NOTE this needs lifetime greater than or equal |
| // to interpreter context. |
| mTfliteModel = tflite::FlatBufferModel::BuildFromFile(modelfile); |
| if (!mTfliteModel) { |
| __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, |
| "Failed to load model %s", modelfile); |
| return; |
| } |
| |
| tflite::ops::builtin::BuiltinOpResolver resolver; |
| tflite::InterpreterBuilder(*mTfliteModel, resolver)(&mTfliteInterpreter); |
| if (!mTfliteInterpreter) { |
| __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, |
| "Failed to create TFlite interpreter"); |
| return; |
| } |
| } |
| |
| BenchmarkModel::~BenchmarkModel() { |
| } |
| |
| bool BenchmarkModel::setInput(const uint8_t* dataPtr, size_t length) { |
| int input = mTfliteInterpreter->inputs()[0]; |
| auto* input_tensor = mTfliteInterpreter->tensor(input); |
| switch (input_tensor->type) { |
| case kTfLiteFloat32: |
| case kTfLiteUInt8: { |
| void* raw = mTfliteInterpreter->typed_tensor<void>(input); |
| memcpy(raw, dataPtr, length); |
| break; |
| } |
| default: |
| __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, |
| "Input tensor type not supported"); |
| return false; |
| } |
| return true; |
| } |
| |
| bool BenchmarkModel::resizeInputTensors(std::vector<int> shape) { |
| // The benchmark only expects single input tensor, hardcoded as 0. |
| int input = mTfliteInterpreter->inputs()[0]; |
| mTfliteInterpreter->ResizeInputTensor(input, shape); |
| if (mTfliteInterpreter->AllocateTensors() != kTfLiteOk) { |
| __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to allocate tensors!"); |
| return false; |
| } |
| return true; |
| } |
| |
| bool BenchmarkModel::runBenchmark(int num_inferences, |
| bool use_nnapi) { |
| mTfliteInterpreter->UseNNAPI(use_nnapi); |
| |
| for(int i = 0; i < num_inferences; i++){ |
| auto status = mTfliteInterpreter->Invoke(); |
| if (status != kTfLiteOk) { |
| __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to invoke: %d!", (int)status); |
| return false; |
| } |
| } |
| return true; |
| } |