Refactor code to prepare for parallel tests am: 2aaa3f84b3 am: 699747fbf7

Change-Id: I38c38fed456634478a2d31dfa68a43d300ba13d4
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1519f47
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.iml
+**/gen/*
+**/.idea/*
+
diff --git a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
index 4334d84..21ce236 100644
--- a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
+++ b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java
@@ -28,12 +28,13 @@
 import androidx.core.app.NotificationManagerCompat;
 
 import com.android.nn.benchmark.core.BenchmarkResult;
-import com.android.nn.benchmark.core.NNTestBase;
 import com.android.nn.benchmark.core.Processor;
 import com.android.nn.benchmark.core.TestModels;
 
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 /** Regularly runs a random selection of the NN API benchmark models */
 public class BenchmarkJobService extends JobService implements Processor.Callback {
@@ -52,6 +53,7 @@
 
     private static int DOGFOOD_MODELS_PER_RUN = 20;
     private BenchmarkResult mTestResults[];
+    private final ExecutorService processorRunner = Executors.newSingleThreadExecutor();
 
 
     @Override
@@ -74,11 +76,10 @@
     }
 
     public void doBenchmark() {
-
         mProcessor = new Processor(this, this, randomModelList());
         mProcessor.setUseNNApi(true);
         mProcessor.setToggleLong(true);
-        mProcessor.start();
+        processorRunner.submit(mProcessor);
     }
 
     public void onBenchmarkFinish(boolean ok) {
diff --git a/res/values/strings.xml b/res/values/strings.xml
index 180e0cb..aab5c97 100644
--- a/res/values/strings.xml
+++ b/res/values/strings.xml
@@ -33,7 +33,7 @@
     <string name="ok">Ok</string>
     <string name="cancel">Cancel</string>
     <string name="settings">settings</string>
-    <string-array
+  <string-array
         name="settings_array">
         <item>Run each test longer, 10 seconds</item>
         <item>Pause 10 seconds between tests</item>
diff --git a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
index 5aaa056..a89626f 100644
--- a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
+++ b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
@@ -23,7 +23,6 @@
 import android.content.Intent;
 import android.content.IntentFilter;
 import android.os.BatteryManager;
-import android.os.Bundle;
 import android.os.Trace;
 import android.test.ActivityInstrumentationTestCase2;
 import android.util.Log;
@@ -35,6 +34,8 @@
 import com.android.nn.benchmark.core.TestModels;
 import com.android.nn.benchmark.core.TestModels.TestModelEntry;
 
+import java.util.concurrent.CountDownLatch;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.runner.RunWith;
@@ -51,6 +52,7 @@
  */
 @RunWith(Parameterized.class)
 public class BenchmarkTestBase extends ActivityInstrumentationTestCase2<NNBenchmark> {
+
     // Only run 1 iteration now to fit the MediumTest time requirement.
     // One iteration means running the tests continuous for 1s.
     private NNBenchmark mActivity;
@@ -95,33 +97,32 @@
     protected void waitUntilCharged() {
         Log.v(NNBenchmark.TAG, "Waiting for the device to charge");
 
-        Object lock = new Object();
+        final CountDownLatch chargedLatch = new CountDownLatch(1);
         BroadcastReceiver receiver = new BroadcastReceiver() {
             @Override
             public void onReceive(Context context, Intent intent) {
                 int level = intent.getIntExtra(BatteryManager.EXTRA_LEVEL, -1);
                 int scale = intent.getIntExtra(BatteryManager.EXTRA_SCALE, -1);
-                int percentage  = level * 100 / scale;
+                int percentage = level * 100 / scale;
                 Log.v(NNBenchmark.TAG, "Battery level: " + percentage + "%");
 
                 int status = intent.getIntExtra(BatteryManager.EXTRA_STATUS, -1);
                 if (status == BatteryManager.BATTERY_STATUS_FULL) {
-                    synchronized (lock) {
-                        lock.notify();
-                    }
+                    chargedLatch.countDown();
                 } else if (status != BatteryManager.BATTERY_STATUS_CHARGING) {
-                    Log.e(NNBenchmark.TAG, "Device is not charging");
+                    Log.e(NNBenchmark.TAG,
+                            String.format("Device is not charging, status is %d", status));
                 }
             }
         };
 
         mActivity.registerReceiver(receiver, new IntentFilter(Intent.ACTION_BATTERY_CHANGED));
-        synchronized (lock) {
-            try {
-                lock.wait();
-            } catch (InterruptedException e) {
-            }
+        try {
+            chargedLatch.await();
+        } catch (InterruptedException ignored) {
+            Thread.currentThread().interrupt();
         }
+
         mActivity.unregisterReceiver(receiver);
     }
 
@@ -139,34 +140,49 @@
         super.tearDown();
     }
 
-    class TestAction implements Runnable {
-        TestModelEntry mTestModel;
+    interface Joinable extends Runnable {
+        // Synchronises the caller with the completion of the current action
+        void join();
+    }
+
+    class TestAction implements Joinable {
+
+        private final TestModelEntry mTestModel;
+        private final float mWarmupTimeSeconds;
+        private final float mRunTimeSeconds;
+        private final CountDownLatch actionComplete;
+
         BenchmarkResult mResult;
-        float mWarmupTimeSeconds;
-        float mRunTimeSeconds;
         Throwable mException;
 
-        public TestAction(TestModelEntry testName) {
-            mTestModel = testName;
-        }
         public TestAction(TestModelEntry testName, float warmupTimeSeconds, float runTimeSeconds) {
             mTestModel = testName;
             mWarmupTimeSeconds = warmupTimeSeconds;
             mRunTimeSeconds = runTimeSeconds;
+            actionComplete = new CountDownLatch(1);
         }
 
         public void run() {
+            Log.v(NNBenchmark.TAG, String.format(
+                    "Starting benchmark for test '%s' running for at least %f seconds",
+                    mTestModel.mTestName,
+                    mRunTimeSeconds));
             try {
-                mResult = mActivity.mProcessor.getInstrumentationResult(
-                    mTestModel, mWarmupTimeSeconds, mRunTimeSeconds);
-            } catch (IOException e) {
+                mResult = mActivity.runSynchronously(
+                        mTestModel, mWarmupTimeSeconds, mRunTimeSeconds);
+                Log.v(NNBenchmark.TAG,
+                        String.format("Benchmark for test '%s' is: %s", mTestModel, mResult));
+            } catch (BenchmarkException | IOException e) {
                 mException = e;
-                e.printStackTrace();
-            }
-            Log.v(NNBenchmark.TAG,
-                    "Benchmark for test \"" + mTestModel.toString() + "\" is: " + mResult);
-            synchronized (this) {
-                this.notify();
+                Log.e(NNBenchmark.TAG,
+                        String.format("Error running Benchmark for test '%s'", mTestModel), e);
+            } catch (Throwable e) {
+                mException = e;
+                Log.e(NNBenchmark.TAG,
+                        String.format("Failure running Benchmark for test '%s'!!", mTestModel), e);
+                throw e;
+            } finally {
+                actionComplete.countDown();
             }
         }
 
@@ -176,20 +192,23 @@
             }
             return mResult;
         }
+
+        @Override
+        public void join() {
+            try {
+                actionComplete.await();
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                Log.v(NNBenchmark.TAG, "Interrupted while waiting for action running", e);
+            }
+        }
     }
 
     // Set the benchmark thread to run on ui thread
     // Synchronized the thread such that the test will wait for the benchmark thread to finish
-    public void runOnUiThread(Runnable action) {
-        synchronized (action) {
-            mActivity.runOnUiThread(action);
-            try {
-                action.wait();
-            } catch (InterruptedException e) {
-                Log.v(NNBenchmark.TAG, "waiting for action running on UI thread is interrupted: " +
-                        e.toString());
-            }
-        }
+    public void runOnUiThread(Joinable action) {
+        mActivity.runOnUiThread(action);
+        action.join();
     }
 
     public void runTest(TestAction ta, String testName) {
diff --git a/src/com/android/nn/benchmark/app/NNBenchmark.java b/src/com/android/nn/benchmark/app/NNBenchmark.java
index a45e043..2ea4478 100644
--- a/src/com/android/nn/benchmark/app/NNBenchmark.java
+++ b/src/com/android/nn/benchmark/app/NNBenchmark.java
@@ -19,33 +19,38 @@
 import android.app.Activity;
 import android.content.Intent;
 import android.os.Bundle;
+import android.util.Log;
 import android.view.WindowManager;
 import android.widget.TextView;
-
+import com.android.nn.benchmark.core.BenchmarkException;
 import com.android.nn.benchmark.core.BenchmarkResult;
 import com.android.nn.benchmark.core.Processor;
+import com.android.nn.benchmark.core.TestModels.TestModelEntry;
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 public class NNBenchmark extends Activity implements Processor.Callback {
-    protected static final String TAG = "NN_BENCHMARK";
+    public static final String TAG = "NN_BENCHMARK";
 
     public static final String EXTRA_ENABLE_LONG = "enable long";
     public static final String EXTRA_ENABLE_PAUSE = "enable pause";
     public static final String EXTRA_DISABLE_NNAPI = "disable NNAPI";
-    public static final String EXTRA_DEMO = "demo";
     public static final String EXTRA_TESTS = "tests";
 
     public static final String EXTRA_RESULTS_TESTS = "tests";
     public static final String EXTRA_RESULTS_RESULTS = "results";
 
     private int mTestList[];
-    private BenchmarkResult mTestResults[];
+
+    private Processor mProcessor;
+    private final ExecutorService executorService = Executors.newSingleThreadExecutor();
 
     private TextView mTextView;
 
     // Initialize the parameters for Instrumentation tests.
     protected void prepareInstrumentationTest() {
         mTestList = new int[1];
-        mTestResults = new BenchmarkResult[1];
         mProcessor = new Processor(this, this, mTestList);
     }
 
@@ -57,9 +62,6 @@
         mProcessor.setCompleteInputSet(completeInputSet);
     }
 
-    private boolean mDoingBenchmark;
-    public Processor mProcessor;
-
     @Override
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
@@ -74,7 +76,8 @@
     protected void onPause() {
         super.onPause();
         if (mProcessor != null) {
-            mProcessor.exit();
+            mProcessor.exitWithTimeout(30000l);
+            mProcessor = null;
         }
     }
 
@@ -104,12 +107,15 @@
         super.onResume();
         Intent i = getIntent();
         mTestList = i.getIntArrayExtra(EXTRA_TESTS);
-        mProcessor = new Processor(this, this, mTestList);
-        mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
-        mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
-        mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false));
-        if (mTestList != null) {
-            mProcessor.start();
+        if (mTestList != null && mTestList.length > 0) {
+            Log.v(TAG, String.format("Starting benchmark with %d test", mTestList.length));
+            mProcessor = new Processor(this, this, mTestList);
+            mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false));
+            mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false));
+            mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false));
+            executorService.submit(mProcessor);
+        } else {
+            Log.v(TAG, "No test to run, doing nothing");
         }
     }
 
@@ -117,4 +123,9 @@
     protected void onDestroy() {
         super.onDestroy();
     }
+
+    public BenchmarkResult runSynchronously(TestModelEntry testModel,
+        float warmupTimeSeconds, float runTimeSeconds) throws IOException, BenchmarkException {
+        return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds);
+    }
 }
diff --git a/src/com/android/nn/benchmark/app/NNScoringTest.java b/src/com/android/nn/benchmark/app/NNScoringTest.java
index 1c339cf..4932eb1 100644
--- a/src/com/android/nn/benchmark/app/NNScoringTest.java
+++ b/src/com/android/nn/benchmark/app/NNScoringTest.java
@@ -36,6 +36,7 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import android.util.Log;
 
 /**
  * Tests that run all models/datasets/backend that are required for scoring the device.
@@ -61,17 +62,15 @@
         super.prepareTest();
     }
 
-    @Test
-    @LargeTest
-    public void testTFLite() throws IOException {
+    private void test(boolean useNnapi) throws IOException {
         if (!TestExternalStorageActivity.testWriteExternalStorage(getActivity(), false)) {
             throw new IOException("No permission to store results in external storage");
         }
 
-        setUseNNApi(false);
+        setUseNNApi(useNnapi);
         setCompleteInputSet(true);
         TestAction ta = new TestAction(mModel, WARMUP_REPEATABLE_SECONDS,
-                COMPLETE_SET_TIMEOUT_SECOND);
+            COMPLETE_SET_TIMEOUT_SECOND);
         runTest(ta, mModel.getTestName());
 
         try (CSVWriter writer = new CSVWriter(getLocalCSVFile())) {
@@ -81,21 +80,14 @@
 
     @Test
     @LargeTest
+    public void testTFLite() throws IOException {
+        test(false);
+    }
+
+    @Test
+    @LargeTest
     public void testNNAPI() throws IOException {
-        if (!TestExternalStorageActivity.testWriteExternalStorage(getActivity(), false)) {
-            throw new IOException("No permission to store results in external storage");
-        }
-
-        setUseNNApi(true);
-        setCompleteInputSet(true);
-        TestAction ta = new TestAction(mModel, WARMUP_REPEATABLE_SECONDS,
-                COMPLETE_SET_TIMEOUT_SECOND);
-        runTest(ta, mModel.getTestName());
-
-
-        try (CSVWriter writer = new CSVWriter(getLocalCSVFile())) {
-            writer.write(ta.getBenchmark());
-        }
+        test(true);
     }
 
     public static File getLocalCSVFile() {
diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java
index 579b4a2..21f8711 100644
--- a/src/com/android/nn/benchmark/core/NNTestBase.java
+++ b/src/com/android/nn/benchmark/core/NNTestBase.java
@@ -262,12 +262,13 @@
         Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                 runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                         flags);
-        if (result.second.size() != extpectedResults ) {
+        if (result.second.size() != extpectedResults) {
             // We reached a timeout or failed to evaluate whole set for other reason, abort.
-            throw new IllegalStateException(
-                    "Failed to evaluate complete input set, expected: "
-                            + extpectedResults +
-                            ", received: " + result.second.size());
+            final String errorMsg = "Failed to evaluate complete input set, expected: "
+                    + extpectedResults +
+                    ", received: " + result.second.size();
+            Log.w(TAG, errorMsg);
+            throw new IllegalStateException(errorMsg);
         }
         return result;
     }
@@ -283,7 +284,7 @@
         }
         List<InferenceResult> resultList = new ArrayList<>();
         if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
-                    timeoutSec, flags)) {
+                timeoutSec, flags)) {
             throw new BenchmarkException("Failed to run benchmark");
         }
         return new Pair<List<InferenceInOutSequence>, List<InferenceResult>>(
@@ -303,21 +304,18 @@
         String modelAssetName = mModelFile + ".tflite";
         AssetManager assetManager = mContext.getAssets();
         try {
-            InputStream in = assetManager.open(modelAssetName);
-
             outFileName = mContext.getCacheDir().getAbsolutePath() + "/" + modelAssetName;
             File outFile = new File(outFileName);
-            OutputStream out = new FileOutputStream(outFile);
 
-            byte[] buffer = new byte[1024];
-            int read;
-            while ((read = in.read(buffer)) != -1) {
-                out.write(buffer, 0, read);
+            try (InputStream in = assetManager.open(modelAssetName);
+                 FileOutputStream out = new FileOutputStream(outFile)) {
+
+                byte[] byteBuffer = new byte[1024];
+                int readBytes = -1;
+                while ((readBytes = in.read(byteBuffer)) != -1) {
+                    out.write(byteBuffer, 0, readBytes);
+                }
             }
-            out.flush();
-
-            in.close();
-            out.close();
         } catch (IOException e) {
             Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
             return null;
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 1aa6008..0017082 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -16,6 +16,8 @@
 
 package com.android.nn.benchmark.core;
 
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+
 import android.content.Context;
 import android.os.Trace;
 import android.util.Log;
@@ -23,21 +25,26 @@
 
 import java.io.IOException;
 import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 /** Processor is a helper thread for running the work without blocking the UI thread. */
-public class Processor extends Thread {
+public class Processor implements Runnable {
 
     public interface Callback {
-        public void onBenchmarkFinish(boolean ok);
+        void onBenchmarkFinish(boolean ok);
 
-        public void onStatusUpdate(int testNumber, int numTests, String modelName);
+        void onStatusUpdate(int testNumber, int numTests, String modelName);
     }
 
     protected static final String TAG = "NN_BENCHMARK";
     private Context mContext;
 
-    private float mLastResult;
-    private boolean mRun = true;
+    private final AtomicBoolean mRun = new AtomicBoolean(true);
+
+    volatile boolean mHasBeenStarted = false;
+    // A Processor is run at most once, so a single-use latch is enough to signal completion
+    private final CountDownLatch mCompleted = new CountDownLatch(1);
     private boolean mDoingBenchmark;
     private NNTestBase mTest;
     private int mTestList[];
@@ -78,18 +85,24 @@
     // Method to retrieve benchmark results for instrumentation tests.
     public BenchmarkResult getInstrumentationResult(
             TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
-            throws IOException {
+            throws IOException, BenchmarkException {
         mTest = changeTest(mTest, t);
-        return getBenchmark(warmupTimeSeconds, runTimeSeconds);
+        BenchmarkResult result = getBenchmark(warmupTimeSeconds, runTimeSeconds);
+        mTest.destroy();
+        mTest = null;
+        return result;
     }
 
-    private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) {
+    private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)
+            throws BenchmarkException {
         if (oldTestBase != null) {
             // Make sure we don't leak memory.
             oldTestBase.destroy();
         }
         NNTestBase tb = t.createNNTestBase(mUseNNApi, false /* enableIntermediateTensorsDump */);
-        tb.setupModel(mContext);
+        if (!tb.setupModel(mContext)) {
+            throw new BenchmarkException("Cannot initialise model");
+        }
         return tb;
     }
 
@@ -133,14 +146,10 @@
             mTest.checkSdkVersion();
         } catch (UnsupportedSdkException e) {
             BenchmarkResult r = new BenchmarkResult(e.getMessage());
-            Log.v(TAG, "Test: " + r.toString());
+            Log.v(TAG, "Unsupported SDK for test: " + r.toString());
             return r;
         }
 
-        mDoingBenchmark = true;
-
-        long result = 0;
-
         // We run a short bit of work before starting the actual test
         // this is to let any power management do its job and respond.
         // For NNAPI systrace usage documentation, see
@@ -163,79 +172,114 @@
             Trace.endSection();
         }
 
-        Log.v(TAG, "Test: " + r.toString());
+        Log.v(TAG, "Completed benchmark loop");
 
-        mDoingBenchmark = false;
         return r;
     }
 
     @Override
     public void run() {
-        while (mRun) {
-            // Our loop for launching tests or benchmarks
-            synchronized (this) {
-                // We may have been asked to exit while waiting
-                if (!mRun) return;
-            }
-
-            try {
-                // Loop over the tests we want to benchmark
-                for (int ct = 0; (ct < mTestList.length) && mRun; ct++) {
-
-                    // For reproducibility we wait a short time for any sporadic work
-                    // created by the user touching the screen to launch the test to pass.
-                    // Also allows for things to settle after the test changes.
-                    try {
-                        sleep(250);
-                    } catch (InterruptedException e) {
-                    }
-
-                    TestModels.TestModelEntry testModel =
-                            TestModels.modelsList().get(mTestList[ct]);
-                    int testNumber = ct + 1;
-                    mCallback.onStatusUpdate(testNumber, mTestList.length, testModel.toString());
-
-                    // Select the next test
-                    mTest = changeTest(mTest, testModel);
-
-                    // If the user selected the "long pause" option, wait
-                    if (mTogglePause) {
-                        for (int i = 0; (i < 100) && mRun; i++) {
-                            try {
-                                sleep(100);
-                            } catch (InterruptedException e) {
-                            }
-                        }
-                    }
-
-                    // Run the test
-                    float warmupTime = 0.3f;
-                    float runTime = 1.f;
-                    if (mToggleLong) {
-                        warmupTime = 2.f;
-                        runTime = 10.f;
-                    }
-                    mTestResults[ct] = getBenchmark(warmupTime, runTime);
+        mHasBeenStarted = true;
+        Log.d(TAG, "Processor starting");
+        try {
+            while (mRun.get()) {
+                try {
+                    benchmarkAllModels();
+                } catch (IOException e) {
+                    Log.e(TAG, "IOException during benchmark run", e);
+                    break;
+                } catch (Throwable e) {
+                    Log.e(TAG, "Error during execution", e);
+                    throw e;
                 }
-                mCallback.onBenchmarkFinish(mRun);
-            } catch (IOException e) {
-                Log.e(TAG, "Exception during benchmark run", e);
+
+                mCallback.onBenchmarkFinish(mRun.get());
+            }
+        } finally {
+            mCompleted.countDown();
+        }
+    }
+
+    private void benchmarkAllModels() throws IOException {
+        Log.i(TAG, String.format("Iterating through %d models", mTestList.length));
+        // Loop over the tests we want to benchmark
+        for (int ct = 0; ct < mTestList.length; ct++) {
+            if (!mRun.get()) {
+                Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
                 break;
             }
+            // For reproducibility we wait a short time for any sporadic work
+            // created by the user touching the screen to launch the test to pass.
+            // Also allows for things to settle after the test changes.
+            try {
+                Thread.sleep(250);
+            } catch (InterruptedException ignored) {
+                Thread.currentThread().interrupt();
+                break;
+            }
+
+            TestModels.TestModelEntry testModel =
+                    TestModels.modelsList().get(mTestList[ct]);
+
+            Log.i(TAG, String.format("%d/%d: '%s'", ct, mTestList.length,
+                    testModel.mTestName));
+            int testNumber = ct + 1;
+            mCallback.onStatusUpdate(testNumber, mTestList.length,
+                    testModel.toString());
+
+            // Select the next test
+            try {
+                mTest = changeTest(mTest, testModel);
+            } catch (BenchmarkException e) {
+                Log.w(TAG, String.format("Cannot initialise test %d: '%s', skipping", ct,
+                        testModel.mTestName), e);
+            }
+
+            // If the user selected the "long pause" option, wait
+            if (mTogglePause) {
+                for (int i = 0; (i < 100) && mRun.get(); i++) {
+                    try {
+                        Thread.sleep(100);
+                    } catch (InterruptedException ignored) {
+                        Thread.currentThread().interrupt();
+                        break;
+                    }
+                }
+            }
+
+            // Run the test
+            float warmupTime = 0.3f;
+            float runTime = 1.f;
+            if (mToggleLong) {
+                warmupTime = 2.f;
+                runTime = 10.f;
+            }
+            Log.i(TAG, "Running test for model " + testModel.mModelName + " file "
+                    + testModel.mModelFile);
+            mTestResults[ct] = getBenchmark(warmupTime, runTime);
         }
     }
 
     public void exit() {
-        mRun = false;
+        exitWithTimeout(-1l);
+    }
 
-        synchronized (this) {
-            notifyAll();
-        }
-        // exit() is called on same thread when run via dogfood BenchmarkJobService
-        if (this != Thread.currentThread()) {
+    public void exitWithTimeout(long timeoutMs) {
+        mRun.set(false);
+
+        if (mHasBeenStarted) {
             try {
-                this.join();
+                if (timeoutMs > 0) {
+                    boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
+                    if (!hasCompleted) {
+                        Log.w(TAG, "Exiting before execution actually completed");
+                    }
+                } else {
+                    mCompleted.await();
+                }
             } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
             }
         }
 
diff --git a/src/com/android/nn/benchmark/core/TestModels.java b/src/com/android/nn/benchmark/core/TestModels.java
index 95dfff5..7dad749 100644
--- a/src/com/android/nn/benchmark/core/TestModels.java
+++ b/src/com/android/nn/benchmark/core/TestModels.java
@@ -16,9 +16,9 @@
 
 package com.android.nn.benchmark.core;
 
-import java.util.List;
 import java.util.ArrayList;
-import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
 
 /** Information about available benchmarking models */
 public class TestModels {
@@ -88,12 +88,13 @@
         }
     }
 
-    static private List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
-    static private volatile boolean sTestModelEntryListFrozen = false;
+    static private final List<TestModelEntry> sTestModelEntryList = new ArrayList<>();
+    static private final AtomicReference<List<TestModelEntry>> frozenEntries = new AtomicReference<>(null);
+
 
     /** Add new benchmark model. */
     static public void registerModel(TestModelEntry model) {
-        if (sTestModelEntryListFrozen) {
+        if (frozenEntries.get() != null) {
             throw new IllegalStateException("Can't register new models after its list is frozen");
         }
         sTestModelEntryList.add(model);
@@ -104,16 +105,8 @@
      * If this method was called at least once, then it's impossible to register new models.
      */
     static public List<TestModelEntry> modelsList() {
-        if (!sTestModelEntryListFrozen) {
-            // If this method was called once, make models list unmodifiable
-            synchronized (TestModels.class) {
-                if (!sTestModelEntryListFrozen) {
-                    sTestModelEntryList = Collections.unmodifiableList(sTestModelEntryList);
-                    sTestModelEntryListFrozen = true;
-                }
-            }
-        }
-        return sTestModelEntryList;
+        frozenEntries.compareAndSet(null, sTestModelEntryList);
+        return frozenEntries.get();
     }
 
     /** Fetch model by its name. */