/**
* Copyright 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
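
// Benchmark wrapper around a TensorFlow Lite model: memory-maps a FlatBuffer
// model file, builds an interpreter, copies input data into the model's
// input tensor, and runs repeated Invoke() calls (optionally through NNAPI).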
#include "run_tflite.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include <android/log.h>
#include <cstdio>
#include <sys/time.h>
#define LOG_TAG "NN_BENCHMARK"
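
// Loads the model and builds the TFLite interpreter. On failure,
// mTfliteModel/mTfliteInterpreter are left null; callers must not use the
// object in that state.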
BenchmarkModel::BenchmarkModel(const char* modelfile) {
  // Memory-map the model. NOTE: the mapping must stay alive at least as long
  // as the interpreter built from it.
  mTfliteModel = tflite::FlatBufferModel::BuildFromFile(modelfile);
  if (!mTfliteModel) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to load model %s",
                        modelfile);
    return;
  }

  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder(*mTfliteModel, resolver)(&mTfliteInterpreter);
  if (!mTfliteInterpreter) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                        "Failed to create TFlite interpreter");
    return;
  }
}

BenchmarkModel::~BenchmarkModel() {}
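
// Copies |length| bytes of raw input data into the model's first input
// tensor. Only float32 and uint8 input tensors are supported, and the
// tensors must already be allocated (see resizeInputTensors()).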
bool BenchmarkModel::setInput(const uint8_t* dataPtr, size_t length) {
  int input = mTfliteInterpreter->inputs()[0];
  auto* input_tensor = mTfliteInterpreter->tensor(input);
  switch (input_tensor->type) {
    case kTfLiteFloat32:
    case kTfLiteUInt8: {
      // Refuse size mismatches rather than over- or under-filling the
      // tensor's backing buffer.
      if (input_tensor->bytes != length) {
        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                            "Input size mismatch: expected %zu bytes, got %zu",
                            input_tensor->bytes, length);
        return false;
      }
      memcpy(input_tensor->data.raw, dataPtr, length);
      break;
    }
    default:
      __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                          "Input tensor type not supported");
      return false;
  }
  return true;
}
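
// Resizes the input tensor to |shape| and (re)allocates all tensor buffers;
// this must succeed before setInput() can copy data safely.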
bool BenchmarkModel::resizeInputTensors(std::vector<int> shape) {
  // The benchmark expects a single input tensor, hardcoded as index 0.
  int input = mTfliteInterpreter->inputs()[0];
  if (mTfliteInterpreter->ResizeInputTensor(input, shape) != kTfLiteOk) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                        "Failed to resize input tensor!");
    return false;
  }
  if (mTfliteInterpreter->AllocateTensors() != kTfLiteOk) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                        "Failed to allocate tensors!");
    return false;
  }
  return true;
}
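
// Runs |num_inferences| back-to-back Invoke() calls, optionally delegated to
// NNAPI, and stops at the first failure.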
bool BenchmarkModel::runBenchmark(int num_inferences, bool use_nnapi) {
  mTfliteInterpreter->UseNNAPI(use_nnapi);
  for (int i = 0; i < num_inferences; i++) {
    auto status = mTfliteInterpreter->Invoke();
    if (status != kTfLiteOk) {
      __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to invoke: %d!",
                          static_cast<int>(status));
      return false;
    }
  }
  return true;
}
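
// Minimal usage sketch (illustrative only; the model path, input shape, and
// input bytes are hypothetical, the member functions are the ones defined
// above):
//
//   BenchmarkModel model("/data/local/tmp/model.tflite");
//   std::vector<uint8_t> input(expected_input_bytes);  // preprocessed data
//   if (model.resizeInputTensors({1, 224, 224, 3}) &&
//       model.setInput(input.data(), input.size())) {
//     model.runBenchmark(/*num_inferences=*/100, /*use_nnapi=*/true);
//   }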