Refactor code to prepare for parallel tests am: 2aaa3f84b3 am: 699747fbf7
Change-Id: I38c38fed456634478a2d31dfa68a43d300ba13d4
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1519f47
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.iml
+**/gen/*
+**/.idea/*
+
diff --git a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
index 4334d84..21ce236 100644
--- a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
+++ b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
@@ -28,12 +28,13 @@
import androidx.core.app.NotificationManagerCompat;
import com.android.nn.benchmark.core.BenchmarkResult;
-import com.android.nn.benchmark.core.NNTestBase;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
import java.util.List;
import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
/** Regularly runs a random selection of the NN API benchmark models */
public class BenchmarkJobService extends JobService implements Processor.Callback {
@@ -52,6 +53,7 @@
private static int DOGFOOD_MODELS_PER_RUN = 20;
private BenchmarkResult mTestResults[];
+ private final ExecutorService processorRunner = Executors.newSingleThreadExecutor();
@Override
@@ -74,11 +76,10 @@
}
public void doBenchmark() {
-
mProcessor = new Processor(this, this, randomModelList());
mProcessor.setUseNNApi(true);
mProcessor.setToggleLong(true);
- mProcessor.start();
+ processorRunner.submit(mProcessor);
}
public void onBenchmarkFinish(boolean ok) {
diff --git a/res/values/strings.xml b/res/values/strings.xml
index 180e0cb..aab5c97 100644
--- a/res/values/strings.xml
+++ b/res/values/strings.xml
@@ -33,7 +33,7 @@
<string name="ok">Ok</string>
<string name="cancel">Cancel</string>
<string name="settings">settings</string>
- <string-array
+ <string-array
name="settings_array">
<item>Run each test longer, 10 seconds</item>
<item>Pause 10 seconds between tests</item>
diff --git a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
index 5aaa056..a89626f 100644
--- a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
+++ b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
@@ -23,7 +23,6 @@
import android.content.Intent;
import android.content.IntentFilter;
import android.os.BatteryManager;
-import android.os.Bundle;
import android.os.Trace;
import android.test.ActivityInstrumentationTestCase2;
import android.util.Log;
@@ -35,6 +34,8 @@
import com.android.nn.benchmark.core.TestModels;
import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import java.util.concurrent.CountDownLatch;
+
import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
@@ -51,6 +52,7 @@
*/
@RunWith(Parameterized.class)
public class BenchmarkTestBase extends ActivityInstrumentationTestCase2<NNBenchmark> {
+
// Only run 1 iteration now to fit the MediumTest time requirement.
// One iteration means running the tests continuous for 1s.
private NNBenchmark mActivity;
@@ -95,33 +97,32 @@
protected void waitUntilCharged() {
Log.v(NNBenchmark.TAG, "Waiting for the device to charge");
- Object lock = new Object();
+ final CountDownLatch chargedLatch = new CountDownLatch(1);
BroadcastReceiver receiver = new BroadcastReceiver() {
@Override
public void onReceive(Context context, Intent intent) {
int level = intent.getIntExtra(BatteryManager.EXTRA_LEVEL, -1);
int scale = intent.getIntExtra(BatteryManager.EXTRA_SCALE, -1);
- int percentage = level * 100 / scale;
+ int percentage = level * 100 / scale;
Log.v(NNBenchmark.TAG, "Battery level: " + percentage + "%");
int status = intent.getIntExtra(BatteryManager.EXTRA_STATUS, -1);
if (status == BatteryManager.BATTERY_STATUS_FULL) {
- synchronized (lock) {
- lock.notify();
- }
+ chargedLatch.countDown();
} else if (status != BatteryManager.BATTERY_STATUS_CHARGING) {
- Log.e(NNBenchmark.TAG, "Device is not charging");
+ Log.e(NNBenchmark.TAG,
+ String.format("Device is not charging, status is %d", status));
}
}
};
mActivity.registerReceiver(receiver, new IntentFilter(Intent.ACTION_BATTERY_CHANGED));
- synchronized (lock) {
- try {
- lock.wait();
- } catch (InterruptedException e) {
- }
+ try {
+ chargedLatch.await();
+ } catch (InterruptedException ignored) {
+ Thread.currentThread().interrupt();
}
+
mActivity.unregisterReceiver(receiver);
}
@@ -139,34 +140,49 @@
super.tearDown();
}
- class TestAction implements Runnable {
- TestModelEntry mTestModel;
+ interface Joinable extends Runnable {
+ // Blocks the caller until the current action has completed.
+ void join();
+ }
+
+ class TestAction implements Joinable {
+
+ private final TestModelEntry mTestModel;
+ private final float mWarmupTimeSeconds;
+ private final float mRunTimeSeconds;
+ private final CountDownLatch actionComplete;
+
BenchmarkResult mResult;
- float mWarmupTimeSeconds;
- float mRunTimeSeconds;
Throwable mException;
- public TestAction(TestModelEntry testName) {
- mTestModel = testName;
- }
public TestAction(TestModelEntry testName, float warmupTimeSeconds, float runTimeSeconds) {
mTestModel = testName;
mWarmupTimeSeconds = warmupTimeSeconds;
mRunTimeSeconds = runTimeSeconds;
+ actionComplete = new CountDownLatch(1);
}
public void run() {
+ Log.v(NNBenchmark.TAG, String.format(
+ "Starting benchmark for test '%s' running for at least %f seconds",
+ mTestModel.mTestName,
+ mRunTimeSeconds));
try {
- mResult = mActivity.mProcessor.getInstrumentationResult(
- mTestModel, mWarmupTimeSeconds, mRunTimeSeconds);
- } catch (IOException e) {
+ mResult = mActivity.runSynchronously(
+ mTestModel, mWarmupTimeSeconds, mRunTimeSeconds);
+ Log.v(NNBenchmark.TAG,
+ String.format("Benchmark for test '%s' is: %s", mTestModel, mResult));
+ } catch (BenchmarkException | IOException e) {
mException = e;
- e.printStackTrace();
- }
- Log.v(NNBenchmark.TAG,
- "Benchmark for test \"" + mTestModel.toString() + "\" is: " + mResult);
- synchronized (this) {
- this.notify();
+ Log.e(NNBenchmark.TAG,
+ String.format("Error running Benchmark for test '%s'", mTestModel), e);
+ } catch (Throwable e) {
+ mException = e;
+ Log.e(NNBenchmark.TAG,
+ String.format("Failure running Benchmark for test '%s'!!", mTestModel), e);
+ throw e;
+ } finally {
+ actionComplete.countDown();
}
}
@@ -176,20 +192,23 @@
}
return mResult;
}
+
+ @Override
+ public void join() {
+ try {
+ actionComplete.await();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ Log.v(NNBenchmark.TAG, "Interrupted while waiting for action running", e);
+ }
+ }
}
// Set the benchmark thread to run on ui thread
// Synchronized the thread such that the test will wait for the benchmark thread to finish
- public void runOnUiThread(Runnable action) {
- synchronized (action) {
- mActivity.runOnUiThread(action);
- try {
- action.wait();
- } catch (InterruptedException e) {
- Log.v(NNBenchmark.TAG, "waiting for action running on UI thread is interrupted: " +
- e.toString());
- }
- }
+ public void runOnUiThread(Joinable action) {
+ mActivity.runOnUiThread(action);
+ action.join();
}
public void runTest(TestAction ta, String testName) {
diff --git a/src/com/android/nn/benchmark/app/NNBenchmark.java b/src/com/android/nn/benchmark/app/NNBenchmark.java
index a45e043..2ea4478 100644
--- a/src/com/android/nn/benchmark/app/NNBenchmark.java
+++ b/src/com/android/nn/benchmark/app/NNBenchmark.java
@@ -19,33 +19,38 @@
import android.app.Activity;
import android.content.Intent;
import android.os.Bundle;
+import android.util.Log;
import android.view.WindowManager;
import android.widget.TextView;
-
+import com.android.nn.benchmark.core.BenchmarkException;
import com.android.nn.benchmark.core.BenchmarkResult;
import com.android.nn.benchmark.core.Processor;
+import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
public class NNBenchmark extends Activity implements Processor.Callback {
- protected static final String TAG = "NN_BENCHMARK";
+ public static final String TAG = "NN_BENCHMARK";
public static final String EXTRA_ENABLE_LONG = "enable long";
public static final String EXTRA_ENABLE_PAUSE = "enable pause";
public static final String EXTRA_DISABLE_NNAPI = "disable NNAPI";
- public static final String EXTRA_DEMO = "demo";
public static final String EXTRA_TESTS = "tests";
public static final String EXTRA_RESULTS_TESTS = "tests";
public static final String EXTRA_RESULTS_RESULTS = "results";
private int mTestList[];
- private BenchmarkResult mTestResults[];
+
+ private Processor mProcessor;
+ private final ExecutorService executorService = Executors.newSingleThreadExecutor();
private TextView mTextView;
// Initialize the parameters for Instrumentation tests.
protected void prepareInstrumentationTest() {
mTestList = new int[1];
- mTestResults = new BenchmarkResult[1];
mProcessor = new Processor(this, this, mTestList);
}
@@ -57,9 +62,6 @@
mProcessor.setCompleteInputSet(completeInputSet);
}
- private boolean mDoingBenchmark;
- public Processor mProcessor;
-
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
@@ -74,7 +76,8 @@
protected void onPause() {
super.onPause();
if (mProcessor != null) {
- mProcessor.exit();
+ mProcessor.exitWithTimeout(30000L); // wait at most 30 s for the running benchmark to stop
+ mProcessor = null;
}
}
@@ -104,12 +107,15 @@
super.onResume();
Intent i = getIntent();
mTestList = i.getIntArrayExtra(EXTRA_TESTS);
- mProcessor = new Processor(this, this, mTestList);
- mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
- mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
- mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false));
- if (mTestList != null) {
- mProcessor.start();
+ if (mTestList != null && mTestList.length > 0) {
+ Log.v(TAG, String.format("Starting benchmark with %d test", mTestList.length));
+ mProcessor = new Processor(this, this, mTestList);
+ mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
+ mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
+ mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false));
+ executorService.submit(mProcessor);
+ } else {
+ Log.v(TAG, "No test to run, doing nothing");
}
}
@@ -117,4 +123,20 @@
protected void onDestroy() {
super.onDestroy();
+ // Stop accepting work; a still-running Processor may finish, then the worker thread exits.
+ executorService.shutdown();
}
+
+ /**
+ * Runs one benchmark on the calling thread and returns its result.
+ *
+ * @throws IllegalStateException if there is no Processor, e.g. after onPause()
+ */
+ public BenchmarkResult runSynchronously(TestModelEntry testModel,
+ float warmupTimeSeconds, float runTimeSeconds) throws IOException, BenchmarkException {
+ final Processor processor = mProcessor;
+ if (processor == null) {
+ throw new IllegalStateException("No Processor available to run the benchmark");
+ }
+ return processor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds);
+ }
}
diff --git a/src/com/android/nn/benchmark/app/NNScoringTest.java b/src/com/android/nn/benchmark/app/NNScoringTest.java
index 1c339cf..4932eb1 100644
--- a/src/com/android/nn/benchmark/app/NNScoringTest.java
+++ b/src/com/android/nn/benchmark/app/NNScoringTest.java
@@ -36,6 +36,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import android.util.Log;
/**
* Tests that run all models/datasets/backend that are required for scoring the device.
@@ -61,17 +62,15 @@
super.prepareTest();
}
- @Test
- @LargeTest
- public void testTFLite() throws IOException {
+ private void test(boolean useNnapi) throws IOException {
if (!TestExternalStorageActivity.testWriteExternalStorage(getActivity(), false)) {
throw new IOException("No permission to store results in external storage");
}
- setUseNNApi(false);
+ setUseNNApi(useNnapi);
setCompleteInputSet(true);
TestAction ta = new TestAction(mModel, WARMUP_REPEATABLE_SECONDS,
- COMPLETE_SET_TIMEOUT_SECOND);
+ COMPLETE_SET_TIMEOUT_SECOND);
runTest(ta, mModel.getTestName());
try (CSVWriter writer = new CSVWriter(getLocalCSVFile())) {
@@ -81,21 +80,14 @@
@Test
@LargeTest
+ public void testTFLite() throws IOException {
+ test(false);
+ }
+
+ @Test
+ @LargeTest
public void testNNAPI() throws IOException {
- if (!TestExternalStorageActivity.testWriteExternalStorage(getActivity(), false)) {
- throw new IOException("No permission to store results in external storage");
- }
-
- setUseNNApi(true);
- setCompleteInputSet(true);
- TestAction ta = new TestAction(mModel, WARMUP_REPEATABLE_SECONDS,
- COMPLETE_SET_TIMEOUT_SECOND);
- runTest(ta, mModel.getTestName());
-
-
- try (CSVWriter writer = new CSVWriter(getLocalCSVFile())) {
- writer.write(ta.getBenchmark());
- }
+ test(true);
}
public static File getLocalCSVFile() {
diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java
index 579b4a2..21f8711 100644
--- a/src/com/android/nn/benchmark/core/NNTestBase.java
+++ b/src/com/android/nn/benchmark/core/NNTestBase.java
@@ -262,12 +262,13 @@
Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
flags);
- if (result.second.size() != extpectedResults ) {
+ if (result.second.size() != extpectedResults) {
// We reached a timeout or failed to evaluate whole set for other reason, abort.
- throw new IllegalStateException(
- "Failed to evaluate complete input set, expected: "
- + extpectedResults +
- ", received: " + result.second.size());
+ final String errorMsg = "Failed to evaluate complete input set, expected: "
+ + extpectedResults +
+ ", received: " + result.second.size();
+ Log.w(TAG, errorMsg);
+ throw new IllegalStateException(errorMsg);
}
return result;
}
@@ -283,7 +284,7 @@
}
List<InferenceResult> resultList = new ArrayList<>();
if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
- timeoutSec, flags)) {
+ timeoutSec, flags)) {
throw new BenchmarkException("Failed to run benchmark");
}
return new Pair<List<InferenceInOutSequence>, List<InferenceResult>>(
@@ -303,21 +304,18 @@
String modelAssetName = mModelFile + ".tflite";
AssetManager assetManager = mContext.getAssets();
try {
- InputStream in = assetManager.open(modelAssetName);
-
outFileName = mContext.getCacheDir().getAbsolutePath() + "/" + modelAssetName;
File outFile = new File(outFileName);
- OutputStream out = new FileOutputStream(outFile);
- byte[] buffer = new byte[1024];
- int read;
- while ((read = in.read(buffer)) != -1) {
- out.write(buffer, 0, read);
+ try (InputStream in = assetManager.open(modelAssetName);
+ FileOutputStream out = new FileOutputStream(outFile)) {
+
+ byte[] byteBuffer = new byte[1024];
+ int readBytes = -1;
+ while ((readBytes = in.read(byteBuffer)) != -1) {
+ out.write(byteBuffer, 0, readBytes);
+ }
}
- out.flush();
-
- in.close();
- out.close();
} catch (IOException e) {
Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
return null;
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 1aa6008..0017082 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -16,6 +16,8 @@
package com.android.nn.benchmark.core;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+
import android.content.Context;
import android.os.Trace;
import android.util.Log;
@@ -23,21 +25,26 @@
import java.io.IOException;
import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicBoolean;
/** Processor is a helper thread for running the work without blocking the UI thread. */
-public class Processor extends Thread {
+public class Processor implements Runnable {
public interface Callback {
- public void onBenchmarkFinish(boolean ok);
+ void onBenchmarkFinish(boolean ok);
- public void onStatusUpdate(int testNumber, int numTests, String modelName);
+ void onStatusUpdate(int testNumber, int numTests, String modelName);
}
protected static final String TAG = "NN_BENCHMARK";
private Context mContext;
- private float mLastResult;
- private boolean mRun = true;
+ private final AtomicBoolean mRun = new AtomicBoolean(true);
+
+ volatile boolean mHasBeenStarted = false;
+ // Counted down when run() returns; a CountDownLatch cannot be reset, so a Processor should only be run once
+ private final CountDownLatch mCompleted = new CountDownLatch(1);
private boolean mDoingBenchmark;
private NNTestBase mTest;
private int mTestList[];
@@ -78,18 +85,27 @@
// Method to retrieve benchmark results for instrumentation tests.
public BenchmarkResult getInstrumentationResult(
TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
- throws IOException {
+ throws IOException, BenchmarkException {
mTest = changeTest(mTest, t);
- return getBenchmark(warmupTimeSeconds, runTimeSeconds);
+ // Release the model even when the benchmark throws, so a failing run does not leak it.
+ try {
+ return getBenchmark(warmupTimeSeconds, runTimeSeconds);
+ } finally {
+ mTest.destroy();
+ mTest = null;
+ }
}
- private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) {
+ private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)
+ throws BenchmarkException {
if (oldTestBase != null) {
// Make sure we don't leak memory.
oldTestBase.destroy();
}
NNTestBase tb = t.createNNTestBase(mUseNNApi, false /* enableIntermediateTensorsDump */);
- tb.setupModel(mContext);
+ if (!tb.setupModel(mContext)) {
+ throw new BenchmarkException("Cannot initialise model");
+ }
return tb;
}
@@ -133,14 +146,10 @@
mTest.checkSdkVersion();
} catch (UnsupportedSdkException e) {
BenchmarkResult r = new BenchmarkResult(e.getMessage());
- Log.v(TAG, "Test: " + r.toString());
+ Log.v(TAG, "Unsupported SDK for test: " + r.toString());
return r;
}
- mDoingBenchmark = true;
-
- long result = 0;
-
// We run a short bit of work before starting the actual test
// this is to let any power management do its job and respond.
// For NNAPI systrace usage documentation, see
@@ -163,79 +172,147 @@
Trace.endSection();
}
- Log.v(TAG, "Test: " + r.toString());
+ Log.v(TAG, "Completed benchmark loop");
- mDoingBenchmark = false;
return r;
}
@Override
public void run() {
- while (mRun) {
- // Our loop for launching tests or benchmarks
- synchronized (this) {
- // We may have been asked to exit while waiting
- if (!mRun) return;
- }
-
- try {
- // Loop over the tests we want to benchmark
- for (int ct = 0; (ct < mTestList.length) && mRun; ct++) {
-
- // For reproducibility we wait a short time for any sporadic work
- // created by the user touching the screen to launch the test to pass.
- // Also allows for things to settle after the test changes.
- try {
- sleep(250);
- } catch (InterruptedException e) {
- }
-
- TestModels.TestModelEntry testModel =
- TestModels.modelsList().get(mTestList[ct]);
- int testNumber = ct + 1;
- mCallback.onStatusUpdate(testNumber, mTestList.length, testModel.toString());
-
- // Select the next test
- mTest = changeTest(mTest, testModel);
-
- // If the user selected the "long pause" option, wait
- if (mTogglePause) {
- for (int i = 0; (i < 100) && mRun; i++) {
- try {
- sleep(100);
- } catch (InterruptedException e) {
- }
- }
- }
-
- // Run the test
- float warmupTime = 0.3f;
- float runTime = 1.f;
- if (mToggleLong) {
- warmupTime = 2.f;
- runTime = 10.f;
- }
- mTestResults[ct] = getBenchmark(warmupTime, runTime);
+ // Written before mHasBeenStarted so exitWithTimeout() never sees a started run without it.
+ mRunnerThread = Thread.currentThread();
+ mHasBeenStarted = true;
+ Log.d(TAG, "Processor starting");
+ try {
+ while (mRun.get()) {
+ try {
+ benchmarkAllModels();
+ } catch (IOException e) {
+ Log.e(TAG, "IOException during benchmark run", e);
+ break;
+ } catch (Throwable e) {
+ Log.e(TAG, "Error during execution", e);
+ throw e;
}
- mCallback.onBenchmarkFinish(mRun);
- } catch (IOException e) {
- Log.e(TAG, "Exception during benchmark run", e);
+
+ // An interrupt aborts benchmarkAllModels() and leaves the interrupt flag set, so any
+ // further iteration would abort straight away: report an unsuccessful run and stop.
+ final boolean interrupted = Thread.currentThread().isInterrupted();
+ mCallback.onBenchmarkFinish(mRun.get() && !interrupted);
+ if (interrupted) {
+ break;
+ }
+ }
+ } finally {
+ mRunnerThread = null;
+ mCompleted.countDown();
+ }
+ }
+
+ private void benchmarkAllModels() throws IOException {
+ Log.i(TAG, String.format("Iterating through %d models", mTestList.length));
+ // Loop over the tests we want to benchmark
+ for (int ct = 0; ct < mTestList.length; ct++) {
+ if (!mRun.get()) {
+ Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
break;
}
+ // For reproducibility we wait a short time for any sporadic work
+ // created by the user touching the screen to launch the test to pass.
+ // Also allows for things to settle after the test changes.
+ try {
+ Thread.sleep(250);
+ } catch (InterruptedException ignored) {
+ Thread.currentThread().interrupt();
+ break;
+ }
+
+ TestModels.TestModelEntry testModel =
+ TestModels.modelsList().get(mTestList[ct]);
+
+ Log.i(TAG, String.format("%d/%d: '%s'", ct, mTestList.length,
+ testModel.mTestName));
+ int testNumber = ct + 1;
+ mCallback.onStatusUpdate(testNumber, mTestList.length,
+ testModel.toString());

+ // Select the next test
+ try {
+ mTest = changeTest(mTest, testModel);
+ } catch (BenchmarkException e) {
+ Log.w(TAG, String.format("Cannot initialise test %d: '%s', skipping", ct,
+ testModel.mTestName), e);
+ // changeTest() has already destroyed the previous test (if any): drop the stale reference
+ // so it is not destroyed twice, and do not benchmark a model that failed to initialise.
+ mTest = null;
+ continue;
+ }
+
+ // If the user selected the "long pause" option, wait
+ if (mTogglePause) {
+ for (int i = 0; (i < 100) && mRun.get(); i++) {
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException ignored) {
+ Thread.currentThread().interrupt();
+ break;
+ }
+ }
+ }
+ // Do not start a (possibly 10 s) benchmark after being asked to stop during the pause.
+ if (!mRun.get() || Thread.currentThread().isInterrupted()) {
+ Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
+ break;
+ }
+
+ // Run the test
+ float warmupTime = 0.3f;
+ float runTime = 1.f;
+ if (mToggleLong) {
+ warmupTime = 2.f;
+ runTime = 10.f;
+ }
+ Log.i(TAG, "Running test for model " + testModel.mModelName + " file "
+ + testModel.mModelFile);
+ mTestResults[ct] = getBenchmark(warmupTime, runTime);
}
}
+
+ // Thread currently executing run(), or null when run() is not executing; lets
+ // exitWithTimeout() avoid waiting for run() to complete from inside run() itself.
+ private volatile Thread mRunnerThread;
+
+ /** Same as {@link #exitWithTimeout(long)} with no timeout. */
public void exit() {
- mRun = false;
+ exitWithTimeout(-1L);
+ }
- synchronized (this) {
- notifyAll();
- }
- // exit() is called on same thread when run via dogfood BenchmarkJobService
- if (this != Thread.currentThread()) {
+
+ /**
+ * Asks the processor to stop and, if run() has been started, waits for it to return:
+ * at most {@code timeoutMs} milliseconds, or indefinitely when {@code timeoutMs <= 0}.
+ * Never waits when called from the thread executing run().
+ */
+ public void exitWithTimeout(long timeoutMs) {
+ mRun.set(false);
+
+ // exit() may be called on the thread executing run(), e.g. from a Callback (the dogfood
+ // BenchmarkJobService does this). Waiting there would deadlock, as mCompleted is only
+ // counted down when run() returns.
+ if (mHasBeenStarted && Thread.currentThread() != mRunnerThread) {
try {
- this.join();
+ if (timeoutMs > 0) {
+ boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
+ if (!hasCompleted) {
+ Log.w(TAG, "Exiting before execution actually completed");
+ }
+ } else {
+ mCompleted.await();
+ }
} catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
}
}
diff --git a/src/com/android/nn/benchmark/core/TestModels.java b/src/com/android/nn/benchmark/core/TestModels.java
index 95dfff5..7dad749 100644
--- a/src/com/android/nn/benchmark/core/TestModels.java
+++ b/src/com/android/nn/benchmark/core/TestModels.java
@@ -16,9 +16,9 @@
package com.android.nn.benchmark.core;
-import java.util.List;
import java.util.ArrayList;
-import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
/** Information about available benchmarking models */
public class TestModels {
@@ -88,12 +88,13 @@
}
}
- static private List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
- static private volatile boolean sTestModelEntryListFrozen = false;
+ static private final List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
+ static private final AtomicReference<List<TestModelEntry>> frozenEntries = new AtomicReference<>(null);
+
/** Add new benchmark model. */
static public void registerModel(TestModelEntry model) {
- if (sTestModelEntryListFrozen) {
+ if (frozenEntries.get() != null) {
throw new IllegalStateException("Can't register new models after its list is frozen");
}
sTestModelEntryList.add(model);
@@ -104,16 +105,10 @@
* If this method was called at least once, then it's impossible to register new models.
*/
static public List<TestModelEntry> modelsList() {
- if (!sTestModelEntryListFrozen) {
- // If this method was called once, make models list unmodifiable
- synchronized (TestModels.class) {
- if (!sTestModelEntryListFrozen) {
- sTestModelEntryList = Collections.unmodifiableList(sTestModelEntryList);
- sTestModelEntryListFrozen = true;
- }
- }
- }
- return sTestModelEntryList;
+ // Publish a read-only view on the first call so that callers cannot modify the registry.
+ frozenEntries.compareAndSet(null,
+ java.util.Collections.unmodifiableList(sTestModelEntryList));
+ return frozenEntries.get();
}
/** Fetch model by its name. */