Merge internal changes manually am: 08374c5630
Original change: https://android-review.googlesource.com/c/platform/test/mlts/benchmark/+/1449198
Change-Id: Iff96e0f28d994b82f21706f445d321318a58cb03
diff --git a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
index 0ce2fe8..3811222 100644
--- a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
+++ b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
@@ -30,6 +30,7 @@
import com.android.nn.benchmark.core.BenchmarkResult;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.benchmark.core.TfLiteBackend;
import java.util.List;
import java.util.Random;
@@ -77,7 +78,7 @@
public void doBenchmark() {
mProcessor = new Processor(this, this, randomModelList());
- mProcessor.setUseNNApi(true);
+ mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); // The randomly picked models run on the NNAPI delegate.
mProcessor.setToggleLong(true);
mProcessor.setMaxRunIterations(1);
processorRunner.submit(mProcessor);
diff --git a/jni/benchmark_jni.cpp b/jni/benchmark_jni.cpp
index a1f8097..21bf2d4 100644
--- a/jni/benchmark_jni.cpp
+++ b/jni/benchmark_jni.cpp
@@ -28,7 +28,49 @@
#include <android/log.h>
#include <android/sharedmem.h>
#include <sys/mman.h>
+#include "tensorflow/lite/nnapi/nnapi_implementation.h"
+// Returns true if an NNAPI device whose name equals _nnApiDeviceName exists.
+// Returns false for a null name, when NNAPI or its device-enumeration entry
+// points (added in API level 29) are unavailable, or when no device matches.
+extern "C" JNIEXPORT jboolean JNICALL
+Java_com_android_nn_benchmark_core_NNTestBase_hasNnApiDevice(
+    JNIEnv *env, jclass /* clazz: static native method */, jstring _nnApiDeviceName) {
+  if (_nnApiDeviceName == NULL) {
+    return false;
+  }
+  const NnApi *nnapi = NnApiImplementation();
+  // On platforms without these APIs the function pointers are null; calling them would crash.
+  if (!nnapi->nnapi_exists || nnapi->ANeuralNetworks_getDeviceCount == nullptr ||
+      nnapi->ANeuralNetworks_getDevice == nullptr ||
+      nnapi->ANeuralNetworksDevice_getName == nullptr) {
+    return false;
+  }
+  const char *nnApiDeviceName = env->GetStringUTFChars(_nnApiDeviceName, NULL);
+  if (nnApiDeviceName == NULL) {
+    return false;  // GetStringUTFChars failed; an OutOfMemoryError is pending.
+  }
+  const std::string device_name(nnApiDeviceName);
+  env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
+
+  uint32_t numDevices = 0;
+  if (nnapi->ANeuralNetworks_getDeviceCount(&numDevices) != ANEURALNETWORKS_NO_ERROR) {
+    return false;
+  }
+  for (uint32_t i = 0; i < numDevices; i++) {
+    ANeuralNetworksDevice *device = nullptr;
+    const char *buffer = nullptr;
+    if (nnapi->ANeuralNetworks_getDevice(i, &device) != ANEURALNETWORKS_NO_ERROR ||
+        nnapi->ANeuralNetworksDevice_getName(device, &buffer) != ANEURALNETWORKS_NO_ERROR ||
+        buffer == nullptr) {
+      continue;  // Skip devices that cannot be queried instead of comparing against null.
+    }
+    if (device_name == buffer) {
+      return true;
+    }
+  }
+  return false;
+}
extern "C"
JNIEXPORT jlong
@@ -37,7 +79,7 @@
JNIEnv *env,
jobject /* this */,
jstring _modelFileName,
- jboolean _useNnApi,
+ jint _tfliteBackend,
jboolean _enableIntermediateTensorsDump,
jstring _nnApiDeviceName,
jboolean _mmapModel,
@@ -53,14 +95,14 @@
: env->GetStringUTFChars(_nnApiCacheDir, NULL);
int nnapiErrno = 0;
void *handle = BenchmarkModel::create(
- modelFileName, _useNnApi, _enableIntermediateTensorsDump, &nnapiErrno,
+ modelFileName, _tfliteBackend, _enableIntermediateTensorsDump, &nnapiErrno,
nnApiDeviceName, _mmapModel, nnApiCacheDir);
env->ReleaseStringUTFChars(_modelFileName, modelFileName);
if (_nnApiDeviceName != NULL) {
env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
}
- if (_useNnApi && nnapiErrno != 0) {
+ if (_tfliteBackend == TFLITE_NNAPI && nnapiErrno != 0) {
jclass nnapiFailureClass = env->FindClass(
"com/android/nn/benchmark/core/NnApiDelegationFailure");
jmethodID constructor =
diff --git a/jni/run_tflite.cpp b/jni/run_tflite.cpp
index 8362af0..dfc3c82 100644
--- a/jni/run_tflite.cpp
+++ b/jni/run_tflite.cpp
@@ -28,9 +28,10 @@
#include <fstream>
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
-#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
+#include "tensorflow/lite/kernels/register.h"
+
#define LOG_TAG "NN_BENCHMARK"
#define FATAL(fmt, ...) \
@@ -69,12 +70,12 @@
} // namespace
-BenchmarkModel* BenchmarkModel::create(const char* modelfile, bool use_nnapi,
+BenchmarkModel* BenchmarkModel::create(const char* modelfile, int tfliteBackend,
bool enable_intermediate_tensors_dump, int* nnapiErrno,
const char* nnapi_device_name, bool mmapModel,
const char* nnapi_cache_dir) {
BenchmarkModel* model = new BenchmarkModel();
- if (!model->init(modelfile, use_nnapi, enable_intermediate_tensors_dump, nnapiErrno,
+ if (!model->init(modelfile, tfliteBackend, enable_intermediate_tensors_dump, nnapiErrno,
nnapi_device_name, mmapModel, nnapi_cache_dir)) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to init model %s", modelfile);
delete model;
@@ -83,12 +84,13 @@
return model;
}
-bool BenchmarkModel::init(const char* modelfile, bool use_nnapi,
+bool BenchmarkModel::init(const char* modelfile, int tfliteBackend,
bool enable_intermediate_tensors_dump, int* nnapiErrno,
const char* nnapi_device_name, bool mmapModel,
const char* nnapi_cache_dir) {
+ __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "BenchmarkModel %s",
+ modelfile);
mModelFile = modelfile;
- mUseNnApi = use_nnapi;
if (nnapi_cache_dir) {
mCacheDir = nnapi_cache_dir;
}
@@ -135,26 +137,64 @@
// Allow Fp16 precision for all models
mTfliteInterpreter->SetAllowFp16PrecisionForFp32(true);
- if (use_nnapi) {
- tflite::StatefulNnApiDelegate::Options nnapi_options;
- nnapi_options.accelerator_name = nnapi_device_name;
- mTfliteNnapiDelegate = std::make_unique<tflite::StatefulNnApiDelegate>(nnapi_options);
- int delegationStatus = mTfliteInterpreter->ModifyGraphWithDelegate(mTfliteNnapiDelegate.get());
- *nnapiErrno = mTfliteNnapiDelegate->GetNnApiErrno();
- if (delegationStatus != kTfLiteOk ||
- *nnapiErrno != ANEURALNETWORKS_NO_ERROR) {
- __android_log_print(
- ANDROID_LOG_ERROR, LOG_TAG,
- "Failed to initialize NNAPI Delegate for model %s, nnapi_errno is %d",
- modelfile, *nnapiErrno);
- return false;
- }
+ mTfliteBackend = tfliteBackend;
+ switch (mTfliteBackend) {
+ case TFLITE_NNAPI: {
+ tflite::StatefulNnApiDelegate::Options nnapi_options;
+ nnapi_options.accelerator_name = nnapi_device_name;
+ mTfliteNnapiDelegate = std::make_unique<tflite::StatefulNnApiDelegate>(nnapi_options);
+ int delegationStatus = mTfliteInterpreter->ModifyGraphWithDelegate(mTfliteNnapiDelegate.get());
+ *nnapiErrno = mTfliteNnapiDelegate->GetNnApiErrno();
+ if (delegationStatus != kTfLiteOk ||
+ *nnapiErrno != ANEURALNETWORKS_NO_ERROR) {
+ __android_log_print(
+ ANDROID_LOG_ERROR, LOG_TAG,
+ "Failed to initialize NNAPI Delegate for model %s, nnapi_errno is %d",
+ modelfile, *nnapiErrno);
+ return false;
+ }
+ } break;
+ case TFLITE_GPU: {
+#if defined(NN_BENCHMARK_ENABLE_GPU)
+ mGpuDelegate = TfLiteGpuDelegateV2Create(/*default options=*/nullptr);
+ if (mTfliteInterpreter->ModifyGraphWithDelegate(mGpuDelegate) !=
+ kTfLiteOk) {
+ __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
+ "Failed to initialize GPU Delegate");
+ return false;
+ }
+#else // !defined(NN_BENCHMARK_ENABLE_GPU)
+ __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
+ "GPU delegate requested but not enabled with "
+ "NN_BENCHMARK_ENABLE_GPU");
+ return false;
+#endif // defined(NN_BENCHMARK_ENABLE_GPU)
+ } break;
+    case TFLITE_CPU:
+      break;
+    default:
+      // Reject unknown values instead of silently benchmarking the CPU path under another label.
+      __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
+                          "Unknown TFLite backend %d", mTfliteBackend);
+      return false;
}
return true;
}
-BenchmarkModel::BenchmarkModel() {}
-BenchmarkModel::~BenchmarkModel() {}
+BenchmarkModel::~BenchmarkModel() {
+  // A delegate must stay valid for the interpreter's whole lifetime. Member destructors run only
+  // after this body, so destroy the interpreter explicitly before freeing any delegate.
+  mTfliteInterpreter.reset();
+  switch (mTfliteBackend) {
+    case TFLITE_GPU: {
+#if defined(NN_BENCHMARK_ENABLE_GPU)
+      TfLiteGpuDelegateV2Delete(mGpuDelegate);
+#endif  // defined(NN_BENCHMARK_ENABLE_GPU)
+    } break;
+    default:
+      break;
+  }
+}
bool BenchmarkModel::setInput(const uint8_t* dataPtr, size_t length) {
int input = mTfliteInterpreter->inputs()[0];
@@ -362,7 +402,7 @@
// Allow Fp16 precision for all models
interpreter->SetAllowFp16PrecisionForFp32(true);
- if (mUseNnApi) {
+ if (mTfliteBackend == TFLITE_NNAPI) {
tflite::StatefulNnApiDelegate::Options nnapi_options;
nnapi_options.accelerator_name = mNnApiDeviceName.empty() ? nullptr : mNnApiDeviceName.c_str();
if (cacheDir) {
diff --git a/jni/run_tflite.h b/jni/run_tflite.h
index 9e0f086..71c7747 100644
--- a/jni/run_tflite.h
+++ b/jni/run_tflite.h
@@ -17,9 +17,12 @@
#ifndef COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
#define COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
+#include "tensorflow/lite/delegates/gpu/delegate.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
+#include <memory>
#include <unistd.h>
#include <vector>
@@ -83,11 +86,16 @@
PREPARE_FROM_CACHE,
};
+/** TFLite backend. Values must match the ordinal() of the Java TfLiteBackend enum values. */
+constexpr int TFLITE_CPU = 0;
+constexpr int TFLITE_NNAPI = 1;
+constexpr int TFLITE_GPU = 2;
+
class BenchmarkModel {
public:
~BenchmarkModel();
- static BenchmarkModel* create(const char* modelfile, bool use_nnapi,
+ static BenchmarkModel* create(const char* modelfile, int tfliteBackend,
bool enable_intermediate_tensors_dump,
int* nnapiErrno, const char* nnapi_device_name,
bool mmapModel, const char* nnapi_cache_dir);
@@ -109,8 +117,8 @@
const std::vector<InferenceInOutSequence>& inOutData);
private:
- BenchmarkModel();
- bool init(const char* modelfile, bool use_nnapi,
+ BenchmarkModel() = default;
+ bool init(const char* modelfile, int tfliteBackend,
bool enable_intermediate_tensors_dump,
int* nnapiErrno, const char* nnapi_device_name,
/* flag to choose between memory mapping the model and initializing
@@ -139,9 +147,14 @@
// Parameters for compilation
std::string mModelFile;
- bool mUseNnApi;
std::optional<std::string> mCacheDir;
std::string mNnApiDeviceName;
+#if defined(NN_BENCHMARK_ENABLE_GPU)
+  TfLiteDelegate* mGpuDelegate = nullptr;
+#endif  // defined(NN_BENCHMARK_ENABLE_GPU)
+  // Initialized here so the destructor never reads an indeterminate value when
+  // init() returns before assigning the selected backend.
+  int mTfliteBackend = TFLITE_CPU;
};
#endif // COM_EXAMPLE_ANDROID_NN_BENCHMARK_RUN_TFLITE_H
diff --git a/src/com/android/nn/benchmark/app/NNBenchmark.java b/src/com/android/nn/benchmark/app/NNBenchmark.java
index a87760d..24878d8 100644
--- a/src/com/android/nn/benchmark/app/NNBenchmark.java
+++ b/src/com/android/nn/benchmark/app/NNBenchmark.java
@@ -27,6 +27,7 @@
import com.android.nn.benchmark.core.BenchmarkResult;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import com.android.nn.benchmark.core.TfLiteBackend;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.ExecutorService;
@@ -59,7 +60,7 @@
}
public void setUseNNApi(boolean useNNApi) {
- mProcessor.setUseNNApi(useNNApi);
+ mProcessor.setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
}
public void setCompleteInputSet(boolean completeInputSet) {
@@ -124,7 +125,8 @@
mProcessor = new Processor(this, this, mTestList);
mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
- mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false));
+        mProcessor.setTfLiteBackend(i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false)
+                ? TfLiteBackend.CPU : TfLiteBackend.NNAPI);
mProcessor.setMaxRunIterations(i.getIntExtra(EXTRA_MAX_ITERATIONS, 0));
executorService.submit(mProcessor);
} else {
diff --git a/src/com/android/nn/benchmark/core/BenchmarkResult.java b/src/com/android/nn/benchmark/core/BenchmarkResult.java
index 8df486e..a765a39 100644
--- a/src/com/android/nn/benchmark/core/BenchmarkResult.java
+++ b/src/com/android/nn/benchmark/core/BenchmarkResult.java
@@ -27,6 +27,7 @@
import java.util.List;
public class BenchmarkResult implements Parcelable {
+ // Used by CTS tests; keep equal to the corresponding TfLiteBackend#toString() values.
public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI";
public final static String BACKEND_TFLITE_CPU = "TFLite_CPU";
diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java
index aeda0b2..2e2e5e3 100644
--- a/src/com/android/nn/benchmark/core/NNTestBase.java
+++ b/src/com/android/nn/benchmark/core/NNTestBase.java
@@ -53,10 +53,12 @@
* @return False if any error occurred, true otherwise
*/
private static native boolean getAcceleratorNames(List<String> resultList);
+    /** Returns whether an NNAPI device named {@code nnApiDeviceName} exists; false for null. */
+ public static native boolean hasNnApiDevice(String nnApiDeviceName);
private synchronized native long initModel(
String modelFileName,
- boolean useNNApi,
+ int tfliteBackend,
boolean enableIntermediateTensorsDump,
String nnApiDeviceName,
boolean mmapModel,
@@ -119,7 +121,7 @@
private final EvaluatorConfig mEvaluatorConfig;
private EvaluatorInterface mEvaluator;
private boolean mHasGoldenOutputs;
- private boolean mUseNNApi = false;
+    private TfLiteBackend mTfLiteBackend = TfLiteBackend.CPU;
private boolean mEnableIntermediateTensorsDump = false;
private final int mMinSdkVersion;
private Optional<String> mNNApiDeviceName = Optional.empty();
@@ -152,12 +154,8 @@
mSampleResults = false;
}
- public void useNNApi() {
- useNNApi(true);
- }
-
- public void useNNApi(boolean value) {
- mUseNNApi = value;
+ public void setTfLiteBackend(TfLiteBackend tfLiteBackend) {
+ mTfLiteBackend = tfLiteBackend;
}
public void enableIntermediateTensorsDump() {
@@ -168,8 +166,12 @@
mEnableIntermediateTensorsDump = value;
}
+ public void useNNApi() {
+ setTfLiteBackend(TfLiteBackend.NNAPI);
+ }
+
public void setNNApiDeviceName(String value) {
- if (!mUseNNApi) {
+ if (mTfLiteBackend != TfLiteBackend.NNAPI) {
Log.e(TAG, "Setting device name has no effect when not using NNAPI");
}
mNNApiDeviceName = Optional.ofNullable(value);
@@ -187,13 +189,15 @@
mTemporaryModelFilePath = copyAssetToFile();
String nnApiCacheDir = mContext.getCodeCacheDir().toString();
mModelHandle = initModel(
- mTemporaryModelFilePath, mUseNNApi, mEnableIntermediateTensorsDump,
+ mTemporaryModelFilePath, mTfLiteBackend.ordinal(), mEnableIntermediateTensorsDump,
mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir);
if (mModelHandle == 0) {
Log.e(TAG, "Failed to init the model");
return false;
}
- resizeInputTensors(mModelHandle, mInputShape);
+ if (!resizeInputTensors(mModelHandle, mInputShape)) {
+ return false;
+ }
if (mEvaluatorConfig != null) {
mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
@@ -310,30 +314,36 @@
return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
}
- /** Run through whole input set (once or mutliple times). */
+    /**
+     * Run through the whole input set, repeated as many times as needed to perform at least
+     * {@code minInferences} inferences.
+     */
public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
- int setRepeat,
+ int minInferences,
float timeoutSec)
throws IOException, BenchmarkException {
int flags = getDefaultFlags();
List<InferenceInOutSequence> ios = getInputOutputAssets();
- int totalSequenceInferencesCount = ios.size() * setRepeat;
- int extpectedResults = 0;
+ int setInferences = 0;
for (InferenceInOutSequence iosSeq : ios) {
- extpectedResults += iosSeq.size();
+ setInferences += iosSeq.size();
}
- extpectedResults *= setRepeat;
+        // ceil(minInferences / setInferences). An empty input set has nothing to repeat, and
+        // dividing by its size would throw ArithmeticException.
+        int setRepeat = setInferences == 0 ? 0 : (minInferences + setInferences - 1) / setInferences;
+ int totalSequenceInferencesCount = ios.size() * setRepeat;
+ int expectedResults = setInferences * setRepeat;
Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
flags);
- if (result.second.size() != extpectedResults) {
+ if (result.second.size() != expectedResults) {
// We reached a timeout or failed to evaluate whole set for other reason, abort.
@SuppressLint("DefaultLocale")
final String errorMsg = String.format(
"Failed to evaluate complete input set, in %f seconds expected: %d, received:"
+ " %d",
- timeoutSec, extpectedResults, result.second.size());
+ timeoutSec, expectedResults, result.second.size());
Log.w(TAG, errorMsg);
throw new IllegalStateException(errorMsg);
}
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 1cf605c..f8aaa04 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -53,7 +53,7 @@
private Processor.Callback mCallback;
- private boolean mUseNNApi;
+ private TfLiteBackend mBackend;
private boolean mMmapModel;
private boolean mCompleteInputSet;
private boolean mToggleLong;
@@ -82,10 +82,16 @@
mRunModelCompilationOnly = false;
mMaxRunIterations = 0;
mBenchmarkCompilationCaching = false;
+ mBackend = TfLiteBackend.CPU;
}
public void setUseNNApi(boolean useNNApi) {
- mUseNNApi = useNNApi;
+ setTfLiteBackend(useNNApi ? TfLiteBackend.NNAPI : TfLiteBackend.CPU);
+ }
+
+    /** Selects the backend used to create test bases and to label results; must not be null. */
+ public void setTfLiteBackend(TfLiteBackend backend) {
+        mBackend = java.util.Objects.requireNonNull(backend, "backend");
}
public void setCompleteInputSet(boolean completeInputSet) {
@@ -156,7 +162,7 @@
public static boolean isTestModelSupportedByAccelerator(Context context,
TestModels.TestModelEntry testModelEntry, String acceleratorName)
throws NnApiDelegationFailure {
- try (NNTestBase tb = testModelEntry.createNNTestBase(/*useNnnapi=*/ true,
+ try (NNTestBase tb = testModelEntry.createNNTestBase(TfLiteBackend.NNAPI,
/*enableIntermediateTensorsDump=*/false,
/*mmapModel=*/ false)) {
tb.setNNApiDeviceName(acceleratorName);
@@ -183,9 +189,9 @@
// Make sure we don't leak memory.
oldTestBase.destroy();
}
- NNTestBase tb = t.createNNTestBase(mUseNNApi, /*enableIntermediateTensorsDump=*/false,
+ NNTestBase tb = t.createNNTestBase(mBackend, /*enableIntermediateTensorsDump=*/false,
mMmapModel);
- if (mUseNNApi) {
+ if (mBackend == TfLiteBackend.NNAPI) {
tb.setNNApiDeviceName(mAcceleratorName);
}
if (!tb.setupModel(mContext)) {
@@ -212,9 +218,7 @@
}
return BenchmarkResult.fromInferenceResults(
mTest.getTestInfo(),
- mUseNNApi
- ? BenchmarkResult.BACKEND_TFLITE_NNAPI
- : BenchmarkResult.BACKEND_TFLITE_CPU,
+ mBackend.toString(),
results.first,
results.second,
mTest.getEvaluator());
@@ -274,7 +278,7 @@
}
// Compilation benchmark
- if (mUseNNApi && mBenchmarkCompilationCaching) {
+ if (mBenchmarkCompilationCaching) {
runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds,
mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r);
}
@@ -371,9 +375,8 @@
if (mRunModelCompilationOnly) {
mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName,
- mUseNNApi
- ? BenchmarkResult.BACKEND_TFLITE_NNAPI
- : BenchmarkResult.BACKEND_TFLITE_CPU, Collections.emptyList(),
+ mBackend.toString(),
+ Collections.emptyList(),
Collections.emptyList(), null);
} else {
// Run the test
diff --git a/src/com/android/nn/benchmark/core/TestModels.java b/src/com/android/nn/benchmark/core/TestModels.java
index cd386ba..9eb90f5 100644
--- a/src/com/android/nn/benchmark/core/TestModels.java
+++ b/src/com/android/nn/benchmark/core/TestModels.java
@@ -76,14 +76,20 @@
mEvaluator, mMinSdkVersion);
}
- public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump) {
- return createNNTestBase(useNNApi, enableIntermediateTensorsDump, /*mmapModel=*/false);
+ public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump) {
+ return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false);
}
- public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump,
+ // Used by CTS tests that still pass a boolean: true selects TfLiteBackend.NNAPI, false TfLiteBackend.CPU.
+ public NNTestBase createNNTestBase(boolean useNNAPI, boolean enableIntermediateTensorsDump) {
+ TfLiteBackend tfLiteBackend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU;
+ return createNNTestBase(tfLiteBackend, enableIntermediateTensorsDump, /*mmapModel=*/false);
+ }
+
+ public NNTestBase createNNTestBase(TfLiteBackend tfLiteBackend, boolean enableIntermediateTensorsDump,
boolean mmapModel) {
NNTestBase test = createNNTestBase();
- test.useNNApi(useNNApi);
+ test.setTfLiteBackend(tfLiteBackend);
test.enableIntermediateTensorsDump(enableIntermediateTensorsDump);
test.setMmapModel(mmapModel);
return test;
diff --git a/src/com/android/nn/benchmark/core/TfLiteBackend.java b/src/com/android/nn/benchmark/core/TfLiteBackend.java
new file mode 100644
index 0000000..cb385c2
--- /dev/null
+++ b/src/com/android/nn/benchmark/core/TfLiteBackend.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.nn.benchmark.core;
+
+/**
+ * TFLite execution backend used by the benchmark.
+ *
+ * <p>The declaration order defines {@link #ordinal()}, which must stay consistent with the
+ * TFLITE_* constants in jni/run_tflite.h. {@link #toString()} yields the reported backend name.
+ */
+public enum TfLiteBackend {
+    CPU("TFLite_CPU"),
+    NNAPI("TFLite_NNAPI"),
+    GPU("TFLite_GPU");
+
+    private final String mName;
+
+    TfLiteBackend(String name) {
+        mName = name;
+    }
+
+    /**
+     * Returns the backend whose {@link #toString()} equals {@code s}, or {@link #CPU} when no
+     * backend matches (including when {@code s} is null).
+     */
+    public static TfLiteBackend parseString(String s) {
+        TfLiteBackend match = CPU;
+        for (TfLiteBackend candidate : values()) {
+            if (candidate.mName.equals(s)) {
+                match = candidate;
+                break;
+            }
+        }
+        return match;
+    }
+
+    @Override
+    public String toString() {
+        return mName;
+    }
+}
diff --git a/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java b/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
index 59ea610..5c4197c 100644
--- a/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
+++ b/src/com/android/nn/benchmark/util/DumpIntermediateTensors.java
@@ -23,6 +23,7 @@
import com.android.nn.benchmark.core.NNTestBase;
import com.android.nn.benchmark.core.TestModels;
import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import com.android.nn.benchmark.core.TfLiteBackend;
import java.io.File;
@@ -44,6 +45,7 @@
public static final String EXTRA_MODEL_NAME = "modelName";
public static final String EXTRA_INPUT_ASSET_INDEX = "inputAssetIndex";
public static final String EXTRA_INPUT_ASSET_SIZE = "inputAssetSize";
+ public static final String EXTRA_TFLITE_BACKEND = "tfLiteBackend"; // NOTE(review): no reader of this extra is visible in this change; confirm it is used.
public static final String DUMP_DIR = "intermediate";
public static final String CPU_DIR = "cpu";
public static final String NNAPI_DIR = "nnapi";
@@ -89,9 +91,10 @@
// Run in CPU and NNAPI mode
for (final boolean useNNAPI : new boolean[]{false, true}) {
String useNNAPIDir = useNNAPI ? NNAPI_DIR : CPU_DIR;
+ TfLiteBackend backend = useNNAPI ? TfLiteBackend.NNAPI : TfLiteBackend.CPU; // This loop only dumps CPU and NNAPI runs; GPU is not covered.
TestModelEntry modelEntry = TestModels.getModelByName(modelName);
try (NNTestBase testBase = modelEntry.createNNTestBase(
- useNNAPI, /*enableIntermediateTensorsDump*/true, /*mmapModel*/false)) {
+ backend, /*enableIntermediateTensorsDump*/true, /*mmapModel*/false)) {
testBase.setupModel(this);
File outputDir = new File(getFilesDir() + "/" + DUMP_DIR +
"/" + modelName, useNNAPIDir);
diff --git a/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java b/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
index e4da19a..68e35d9 100644
--- a/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
+++ b/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
@@ -27,6 +27,7 @@
import com.android.nn.benchmark.core.NnApiDelegationFailure;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.benchmark.core.TfLiteBackend;
import java.io.IOException;
import java.util.ArrayList;
@@ -122,7 +123,7 @@
public void onStatusUpdate(int testNumber, int numTests, String modelName) {
}
}, new int[0]);
- mProcessor.setUseNNApi(true);
+ mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); // NNAPI is required so the accelerator name below takes effect.
mProcessor.setCompleteInputSet(false);
mProcessor.setNnApiAcceleratorName(acceleratorName);
mTestModelEntry = testModelEntry;
@@ -152,4 +153,4 @@
return true;
}
}
-}
\ No newline at end of file
+}
diff --git a/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java b/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
index 038d554..c203441 100644
--- a/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
+++ b/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
@@ -25,6 +25,7 @@
import com.android.nn.benchmark.core.BenchmarkResult;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.benchmark.core.TfLiteBackend;
import com.android.nn.crashtest.app.AcceleratorSpecificTestSupport;
import com.android.nn.crashtest.core.CrashTest;
import com.android.nn.crashtest.core.CrashTestCoordinator;
@@ -226,7 +227,7 @@
final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start) {
return () -> {
Processor benchmarkProcessor = new Processor(mContext, mNoOpCallback, new int[0]);
- benchmarkProcessor.setUseNNApi(true);
+ benchmarkProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); // Inference runs target mAcceleratorName through NNAPI.
benchmarkProcessor.setNnApiAcceleratorName(mAcceleratorName);
if (start != null) {
start.countDown();
@@ -252,7 +253,7 @@
mStart = start;
mTestModelEntry = testModelEntry;
mProcessor = new Processor(context, mNoOpCallback, new int[0]);
- mProcessor.setUseNNApi(true);
+ mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI); // Compilation-only runs on acceleratorName through NNAPI.
mProcessor.setNnApiAcceleratorName(acceleratorName);
mProcessor.setRunModelCompilationOnly(true);
mRun = true;
diff --git a/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java b/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java
index 5fe204b..9f01128 100644
--- a/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java
+++ b/src/com/android/nn/crashtest/core/test/RunModelsInParallel.java
@@ -26,6 +26,7 @@
import com.android.nn.benchmark.core.Processor;
+import com.android.nn.benchmark.core.TfLiteBackend;
import com.android.nn.crashtest.core.CrashTest;
import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer;
import java.time.Duration;
import java.util.ArrayList;
@@ -131,7 +132,7 @@
public void onStatusUpdate(int testNumber, int numTests, String modelName) {
}
}, testList);
- result.setUseNNApi(true);
+ result.setTfLiteBackend(TfLiteBackend.NNAPI);
result.setCompleteInputSet(false);
result.setNnApiAcceleratorName(mAcceleratorName);
result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels);