blob: a265cb5dc0f2f3fee529aab602564cb272259d13 [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.nn.benchmark.core;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import android.content.Context;
import android.os.Trace;
import android.util.Log;
import android.util.Pair;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Processor is a helper {@link Runnable} for running the benchmark work without blocking the UI
 * thread.
 *
 * <p>When run, it benchmarks every model in the test list, reporting progress through
 * {@link Callback#onStatusUpdate} and the final outcome through
 * {@link Callback#onBenchmarkFinish}. The whole list is benchmarked repeatedly until
 * {@link #exit()} / {@link #exitWithTimeout(long)} is called, the running thread is interrupted,
 * or a benchmark fails. The stop request is never reset, so a Processor is meant to run once.
 */
public class Processor implements Runnable {
    /** Receives progress and completion notifications from {@link Processor#run()}. */
    public interface Callback {
        /**
         * Called when the run loop stops; {@code ok} is false if an {@link IOException} or
         * {@link BenchmarkException} aborted the run.
         */
        void onBenchmarkFinish(boolean ok);

        /** Called before each model is benchmarked; {@code testNumber} is 1-based. */
        void onStatusUpdate(int testNumber, int numTests, String modelName);
    }

    protected static final String TAG = "NN_BENCHMARK";

    private Context mContext;
    // Cleared by exit()/exitWithTimeout() to ask the benchmark loop to stop.
    private final AtomicBoolean mRun = new AtomicBoolean(true);
    // Set when run() begins so exitWithTimeout() knows whether it must wait for completion.
    volatile boolean mHasBeenStarted = false;
    // You cannot restart a thread, so the completion flag is final
    private final CountDownLatch mCompleted = new CountDownLatch(1);

    // The currently initialised test, or null when none is loaded.
    private NNTestBase mTest;
    // Indices into TestModels.modelsList() to benchmark (may be null for instrumentation use).
    private int[] mTestList;
    // One result slot per entry of mTestList; null when mTestList is null.
    private BenchmarkResult[] mTestResults;
    private Processor.Callback mCallback;

    private boolean mUseNNApi;
    private boolean mCompleteInputSet;
    private boolean mToggleLong;
    private boolean mTogglePause;
    private String mAcceleratorName;
    private boolean mIgnoreUnsupportedModels;
    private boolean mRunModelCompilationOnly;

    public Processor(Context context, Processor.Callback callback, int[] testList) {
        mContext = context;
        mCallback = callback;
        mTestList = testList;
        if (mTestList != null) {
            mTestResults = new BenchmarkResult[mTestList.length];
        }
        mAcceleratorName = null;
        mIgnoreUnsupportedModels = false;
        mRunModelCompilationOnly = false;
    }

    /** Selects the NNAPI backend ({@code true}) or the TFLite CPU backend ({@code false}). */
    public void setUseNNApi(boolean useNNApi) {
        mUseNNApi = useNNApi;
    }

    /** Whether the measured run should use the complete input set. */
    public void setCompleteInputSet(boolean completeInputSet) {
        mCompleteInputSet = completeInputSet;
    }

    /** Uses a 2s warmup / 10s run instead of the default 0.3s / 1s. */
    public void setToggleLong(boolean toggleLong) {
        mToggleLong = toggleLong;
    }

    /** Waits up to 10s (interruptible by a stop request) before running each model. */
    public void setTogglePause(boolean togglePause) {
        mTogglePause = togglePause;
    }

    /** NNAPI device name passed to each test when NNAPI is used; null by default. */
    public void setNnApiAcceleratorName(String acceleratorName) {
        mAcceleratorName = acceleratorName;
    }

    /** If true, models that fail to initialise are skipped instead of aborting the run. */
    public void setIgnoreUnsupportedModels(boolean value) {
        mIgnoreUnsupportedModels = value;
    }

    /** If true, models are only set up and an empty result is recorded for each. */
    public void setRunModelCompilationOnly(boolean value) {
        mRunModelCompilationOnly = value;
    }

    /**
     * Method to retrieve benchmark results for instrumentation tests.
     *
     * <p>Loads {@code t}, benchmarks it and always releases the loaded test afterwards, even if
     * benchmarking throws.
     *
     * @throws UnsupportedModelException (via {@link BenchmarkException}'s hierarchy, presumably)
     *     if the model cannot be initialised
     */
    public BenchmarkResult getInstrumentationResult(
            TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
            throws IOException, BenchmarkException {
        changeTest(t);
        try {
            return getBenchmark(warmupTimeSeconds, runTimeSeconds);
        } finally {
            releaseTest();
        }
    }

    /**
     * Returns whether {@code testModelEntry} can be set up with NNAPI on {@code acceleratorName}.
     * The temporary test instance is always destroyed.
     */
    public static boolean isTestModelSupportedByAccelerator(Context context,
            TestModels.TestModelEntry testModelEntry, String acceleratorName) {
        NNTestBase tb = testModelEntry.createNNTestBase(true,
                false /* enableIntermediateTensorsDump */);
        tb.setNNApiDeviceName(acceleratorName);
        try {
            return tb.setupModel(context);
        } finally {
            tb.destroy();
        }
    }

    /**
     * Releases the current test (if any) and replaces it with a newly initialised test for
     * {@code t}.
     *
     * @throws UnsupportedModelException if the model cannot be set up; the partially created
     *     test is destroyed and {@link #mTest} is left {@code null}
     */
    private void changeTest(TestModels.TestModelEntry t) throws UnsupportedModelException {
        // Make sure we don't leak memory.
        releaseTest();
        NNTestBase tb = t.createNNTestBase(mUseNNApi, false /* enableIntermediateTensorsDump */);
        boolean initialised = false;
        try {
            if (mUseNNApi) {
                tb.setNNApiDeviceName(mAcceleratorName);
            }
            if (!tb.setupModel(mContext)) {
                throw new UnsupportedModelException("Cannot initialise model");
            }
            initialised = true;
        } finally {
            if (!initialised) {
                // Don't leak the test instance that failed to set up.
                tb.destroy();
            }
        }
        mTest = tb;
    }

    /**
     * Destroys the current test, if any. The field is cleared first so a failing
     * {@code destroy()} can never lead to a second destroy of the same instance.
     */
    private void releaseTest() {
        NNTestBase test = mTest;
        mTest = null;
        if (test != null) {
            test.destroy();
        }
    }

    /**
     * Runs the loaded test and packages the outcome as a {@link BenchmarkResult}.
     *
     * <p>If {@code minTime} (seconds) is positive, it runs a benchmark loop of (presumably) at
     * least that duration, over the complete input set if requested. Otherwise it runs a single
     * inference. A {@link BenchmarkException} is converted into an error result carrying its
     * message.
     */
    private BenchmarkResult runBenchmarkLoop(float minTime, boolean completeInputSet)
            throws IOException {
        try {
            // Run the kernel
            Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
            if (minTime > 0.f) {
                if (completeInputSet) {
                    results = mTest.runBenchmarkCompleteInputSet(1, minTime);
                } else {
                    results = mTest.runBenchmark(minTime);
                }
            } else {
                results = mTest.runInferenceOnce();
            }
            return BenchmarkResult.fromInferenceResults(
                    mTest.getTestInfo(),
                    mUseNNApi
                            ? BenchmarkResult.BACKEND_TFLITE_NNAPI
                            : BenchmarkResult.BACKEND_TFLITE_CPU,
                    results.first,
                    results.second,
                    mTest.getEvaluator());
        } catch (BenchmarkException e) {
            return new BenchmarkResult(e.getMessage());
        }
    }

    /**
     * Returns the internal results array, indexed like the test list. Entries for models that
     * were skipped or not reached are {@code null}. The array is not copied.
     */
    public BenchmarkResult[] getTestResults() {
        return mTestResults;
    }

    /**
     * Gets a benchmark result for the loaded test: a warmup run whose result is discarded,
     * followed by the measured run. Each phase is wrapped in a systrace section.
     */
    private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds)
            throws IOException {
        try {
            mTest.checkSdkVersion();
        } catch (UnsupportedSdkException e) {
            BenchmarkResult r = new BenchmarkResult(e.getMessage());
            Log.w(TAG, "Unsupported SDK for test: " + r.toString());
            return r;
        }

        // We run a short bit of work before starting the actual test
        // this is to let any power management do its job and respond.
        // For NNAPI systrace usage documentation, see
        // frameworks/ml/nn/common/include/Tracing.h.
        Trace.beginSection("[NN_LA_PWU]runBenchmarkLoop");
        try {
            runBenchmarkLoop(warmupTimeSeconds, false);
        } finally {
            Trace.endSection();
        }

        // Run the actual benchmark
        Trace.beginSection("[NN_LA_PBM]runBenchmarkLoop");
        try {
            return runBenchmarkLoop(runTimeSeconds, mCompleteInputSet);
        } finally {
            Trace.endSection();
        }
    }

    /**
     * Benchmarks the test list repeatedly until a stop is requested, the thread is interrupted,
     * or a benchmark fails. The loaded test is released on this thread before completion is
     * signalled, so {@link #exitWithTimeout(long)} never destroys it while it is in use.
     */
    @Override
    public void run() {
        mHasBeenStarted = true;
        Log.d(TAG, "Processor starting");
        boolean success = true;
        try {
            // Also stop on interrupt: the flag is restored on interruption, so looping again
            // would make every subsequent sleep fail immediately and busy-spin forever.
            while (mRun.get() && !Thread.currentThread().isInterrupted()) {
                try {
                    benchmarkAllModels();
                    Log.d(TAG, "Processor completed work");
                } catch (IOException | BenchmarkException e) {
                    Log.e(TAG, "Exception during benchmark run", e);
                    success = false;
                    break;
                } catch (Throwable e) {
                    Log.e(TAG, "Error during execution", e);
                    throw e;
                }
            }
            mCallback.onBenchmarkFinish(success);
        } finally {
            try {
                releaseTest();
            } finally {
                mCompleted.countDown();
            }
        }
    }

    /** Runs one pass over the test list, storing each model's result in mTestResults. */
    private void benchmarkAllModels() throws IOException, BenchmarkException {
        // Loop over the tests we want to benchmark
        for (int ct = 0; ct < mTestList.length; ct++) {
            if (!mRun.get()) {
                Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
                break;
            }

            // For reproducibility we wait a short time for any sporadic work
            // created by the user touching the screen to launch the test to pass.
            // Also allows for things to settle after the test changes.
            if (!sleepMs(250)) {
                break;
            }

            TestModels.TestModelEntry testModel = TestModels.modelsList().get(mTestList[ct]);
            int testNumber = ct + 1;
            mCallback.onStatusUpdate(testNumber, mTestList.length, testModel.toString());

            // Select the next test
            try {
                changeTest(testModel);
            } catch (UnsupportedModelException e) {
                if (mIgnoreUnsupportedModels) {
                    Log.d(TAG, String.format(
                            "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct,
                            testModel.mTestName, mAcceleratorName));
                    // No test is loaded for this model, so there is nothing to run.
                    continue;
                }
                Log.e(TAG, String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct,
                        testModel.mTestName, mAcceleratorName), e);
                throw e;
            }

            // If the user selected the "long pause" option, wait
            if (mTogglePause) {
                for (int i = 0; (i < 100) && mRun.get(); i++) {
                    if (!sleepMs(100)) {
                        break;
                    }
                }
            }

            // Don't start a (possibly long) run if a stop was requested while we were waiting.
            if (!mRun.get() || Thread.currentThread().isInterrupted()) {
                Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
                break;
            }

            if (mRunModelCompilationOnly) {
                mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName,
                        mUseNNApi
                                ? BenchmarkResult.BACKEND_TFLITE_NNAPI
                                : BenchmarkResult.BACKEND_TFLITE_CPU, Collections.emptyList(),
                        Collections.emptyList(), null);
            } else {
                // Run the test
                float warmupTime = 0.3f;
                float runTime = 1.f;
                if (mToggleLong) {
                    warmupTime = 2.f;
                    runTime = 10.f;
                }
                mTestResults[ct] = getBenchmark(warmupTime, runTime);
            }
        }
    }

    /**
     * Sleeps for {@code millis} ms. Returns false, with the thread's interrupt flag restored, if
     * the sleep was interrupted.
     */
    private static boolean sleepMs(long millis) {
        try {
            Thread.sleep(millis);
            return true;
        } catch (InterruptedException ignored) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /** Requests a stop and waits without a timeout for a started run to finish. */
    public void exit() {
        exitWithTimeout(-1L);
    }

    /**
     * Requests the benchmark loop to stop. If {@link #run()} has started, waits for it to finish.
     * A positive {@code timeoutMs} bounds the wait; zero or negative values wait indefinitely.
     * The loaded test is released here only when no worker can still be using it. If the wait
     * times out or is interrupted, the worker releases the test itself when {@code run()} ends.
     */
    public void exitWithTimeout(long timeoutMs) {
        mRun.set(false);
        if (mHasBeenStarted) {
            Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs));
            try {
                if (timeoutMs > 0) {
                    boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
                    if (!hasCompleted) {
                        Log.w(TAG, "Exiting before execution actually completed");
                        // The worker may still be using mTest; destroying it now is unsafe.
                        return;
                    }
                } else {
                    mCompleted.await();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
                // Completion is unknown, so leave cleanup to the worker.
                return;
            }
        }
        Log.d(TAG, "Done, cleaning up");
        releaseTest();
    }
}