Merge "Add test for performance degradation"
diff --git a/AndroidManifest.xml b/AndroidManifest.xml
index 4ac42c5..ffda154 100644
--- a/AndroidManifest.xml
+++ b/AndroidManifest.xml
@@ -56,6 +56,11 @@
<action android:name="android.intent.action.MAIN" />
</intent-filter>
</activity>
+ <activity android:name="com.android.nn.crashtest.app.NNPerformanceDegradationTestActivity">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ </intent-filter>
+ </activity>
<service android:name="com.android.nn.crashtest.core.OutOfProcessCrashTestService"
android:process=":CrashTest" />
diff --git a/README.txt b/README.txt
index 99042b9..f9e2c95 100644
--- a/README.txt
+++ b/README.txt
@@ -71,4 +71,7 @@
* model-load-random-stress: test compiling a large set of randomly generated models
-* inference-random-stress: test running a large set of randomly generated models
\ No newline at end of file
+* inference-random-stress: test running a large set of randomly generated models
+
+* performance-degradation-stress: verifies that accelerator inference speed does not degrade
+beyond a given threshold while a concurrent workload is running
\ No newline at end of file
diff --git a/build_and_run_benchmark.sh b/build_and_run_benchmark.sh
index b1ccec7..f3dbb82 100755
--- a/build_and_run_benchmark.sh
+++ b/build_and_run_benchmark.sh
@@ -100,13 +100,17 @@
APP="$CRASH_TEST_APP"
CLASS=com.android.nn.crashtest.app.NNRandomGraphExecutionTest
;;
+ performance-degradation-stress)
+ APP="$CRASH_TEST_APP"
+ CLASS=com.android.nn.crashtest.app.NNPerformanceDegradationTest
+ ;;
*)
echo "Unknown execution mode: $1"
echo "Known modes: scoring (default), inference-stress, model-loading-stress, " \
"parallel-inference-stress, parallel-inference-stress-in-process, " \
"client-early-termination-stress, multi-process-inference-stress, " \
"multi-process-model-load-stress memory-mapped-model-load-stress, " \
- "model-load-random-stress inference-random-stress"
+ "model-load-random-stress, inference-random-stress, performance-degradation-stress"
exit 1
;;
esac
diff --git a/crashtest/Android.mk b/crashtest/Android.mk
index 20890b5..45a681c 100644
--- a/crashtest/Android.mk
+++ b/crashtest/Android.mk
@@ -29,7 +29,8 @@
$(call all-java-files-under, ../src/com/android/nn/benchmark/evaluators) \
$(call all-java-files-under, ../src/com/android/nn/benchmark/imageprocessors) \
$(call all-java-files-under, ../src/com/android/nn/benchmark/util) \
- $(call all-java-files-under, ../src/com/android/nn/crashtest/core)
+ $(call all-java-files-under, ../src/com/android/nn/crashtest/core) \
+ ../src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
LOCAL_JNI_SHARED_LIBRARIES := libnnbenchmark_jni
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 4e799c4..778e5d0 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -129,12 +129,14 @@
}
// Method to retrieve benchmark results for instrumentation tests.
+ // Returns null if the processor is configured to run compilation only
public BenchmarkResult getInstrumentationResult(
TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
throws IOException, BenchmarkException {
mTest = changeTest(mTest, t);
try {
- BenchmarkResult result = getBenchmark(warmupTimeSeconds, runTimeSeconds);
+ BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(warmupTimeSeconds,
+ runTimeSeconds);
return result;
} finally {
mTest.destroy();
diff --git a/src/com/android/nn/benchmark/core/TestModelsListLoader.java b/src/com/android/nn/benchmark/core/TestModelsListLoader.java
index 61dc278..3e25d69 100644
--- a/src/com/android/nn/benchmark/core/TestModelsListLoader.java
+++ b/src/com/android/nn/benchmark/core/TestModelsListLoader.java
@@ -17,6 +17,7 @@
package com.android.nn.benchmark.core;
import android.content.res.AssetManager;
+import android.util.Log;
import org.json.JSONArray;
import org.json.JSONException;
@@ -29,6 +30,7 @@
/** Helper class to register test model definitions from assets data */
public class TestModelsListLoader {
+ private static final String TAG = "NN_BENCHMARK";
/**
* Parse list of models in form of json data.
@@ -175,8 +177,10 @@
parseJSONModelsList(readAssetsFileAsString(
assetManager.open(MODELS_LIST_ROOT + "/" + file)));
} catch (JSONException e) {
+ Log.e(TAG, "error reading json model list", e);
throw new IOException("JSON error in " + file, e);
} catch (Exception e) {
+ Log.e(TAG, "error parsing json model list", e);
// Wrap exception to add a filename to it
throw new IOException("Error while parsing " + file, e);
}
diff --git a/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java b/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
index 87b6180..e4da19a 100644
--- a/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
+++ b/src/com/android/nn/crashtest/app/AcceleratorSpecificTestSupport.java
@@ -19,7 +19,7 @@
import android.content.Context;
import android.util.Log;
-import androidx.test.platform.app.InstrumentationRegistry;
+import androidx.test.InstrumentationRegistry;
import com.android.nn.benchmark.core.BenchmarkException;
import com.android.nn.benchmark.core.BenchmarkResult;
@@ -40,7 +40,7 @@
public interface AcceleratorSpecificTestSupport {
String TAG = "AcceleratorTest";
- default Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator(
+ static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator(
Context context, String acceleratorName) throws NnApiDelegationFailure {
for (TestModels.TestModelEntry model : TestModels.modelsList()) {
if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
@@ -50,6 +50,17 @@
return Optional.empty();
}
+ /**
+  * Collects every entry of {@code TestModels.modelsList()} that
+  * {@code Processor.isTestModelSupportedByAccelerator} reports as supported by the given
+  * accelerator.
+  *
+  * @param context context handed to the support check
+  * @param acceleratorName name of the accelerator to check each model against
+  * @return the supported models, in the order they appear in the models list; empty when
+  *         no model is supported
+  * @throws NnApiDelegationFailure propagated from the support check
+  */
+ static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator(
+         Context context, String acceleratorName) throws NnApiDelegationFailure {
+     final List<TestModels.TestModelEntry> supportedModels = new ArrayList<>();
+     for (TestModels.TestModelEntry candidate : TestModels.modelsList()) {
+         final boolean supported = Processor.isTestModelSupportedByAccelerator(
+                 context, candidate, acceleratorName);
+         if (!supported) {
+             continue;
+         }
+         supportedModels.add(candidate);
+     }
+     return supportedModels;
+ }
+
/**
 * Returns {@code min} plus a pseudo-random fraction (from {@link Math#random()}) of the
 * distance {@code max - min}, truncated toward zero.
 *
 * @param min value the random offset is added to
 * @param max value whose distance from {@code min} scales the random offset
 */
default long ramdomInRange(long min, long max) {
    final long span = max - min;
    final long offset = (long) (Math.random() * span);
    return min + offset;
}
diff --git a/src/com/android/nn/crashtest/app/NNClientEarlyTerminationTest.java b/src/com/android/nn/crashtest/app/NNClientEarlyTerminationTest.java
index c3a1f61..b8d4188 100644
--- a/src/com/android/nn/crashtest/app/NNClientEarlyTerminationTest.java
+++ b/src/com/android/nn/crashtest/app/NNClientEarlyTerminationTest.java
@@ -94,7 +94,7 @@
final NNParallelTestActivity activity = getActivity();
Optional<TestModels.TestModelEntry> modelForLivenessTest =
- findTestModelRunningOnAccelerator(activity, mAcceleratorName);
+ AcceleratorSpecificTestSupport.findTestModelRunningOnAccelerator(activity, mAcceleratorName);
assertTrue("No model available to be run on accelerator " + mAcceleratorName,
modelForLivenessTest.isPresent());
diff --git a/src/com/android/nn/crashtest/app/NNMemoryMappedModelCompilationTest.java b/src/com/android/nn/crashtest/app/NNMemoryMappedModelCompilationTest.java
index b859ece..b652396 100644
--- a/src/com/android/nn/crashtest/app/NNMemoryMappedModelCompilationTest.java
+++ b/src/com/android/nn/crashtest/app/NNMemoryMappedModelCompilationTest.java
@@ -98,7 +98,7 @@
final NNParallelTestActivity activity = getActivity();
Optional<TestModels.TestModelEntry> modelForLivenessTest =
- findTestModelRunningOnAccelerator(activity, mAcceleratorName);
+ AcceleratorSpecificTestSupport.findTestModelRunningOnAccelerator(activity, mAcceleratorName);
assertTrue("No model available to be run on accelerator " + mAcceleratorName,
modelForLivenessTest.isPresent());
diff --git a/src/com/android/nn/crashtest/app/NNMultipleProcessTest.java b/src/com/android/nn/crashtest/app/NNMultipleProcessTest.java
index 80b0be0..4c81b93 100644
--- a/src/com/android/nn/crashtest/app/NNMultipleProcessTest.java
+++ b/src/com/android/nn/crashtest/app/NNMultipleProcessTest.java
@@ -81,7 +81,7 @@
protected Optional<TestModels.TestModelEntry> findModelForLivenessTest()
throws NnApiDelegationFailure {
- return findTestModelRunningOnAccelerator(
+ return AcceleratorSpecificTestSupport.findTestModelRunningOnAccelerator(
getInstrumentation().getTargetContext(), mAcceleratorName);
}
diff --git a/src/com/android/nn/crashtest/app/NNPerformanceDegradationTest.java b/src/com/android/nn/crashtest/app/NNPerformanceDegradationTest.java
new file mode 100644
index 0000000..b91050f
--- /dev/null
+++ b/src/com/android/nn/crashtest/app/NNPerformanceDegradationTest.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.nn.crashtest.app;
+
+import android.content.Intent;
+import android.test.ActivityInstrumentationTestCase2;
+import android.test.UiThreadTest;
+
+import androidx.test.InstrumentationRegistry;
+
+import com.android.nn.benchmark.app.BenchmarkTestBase;
+import com.android.nn.crashtest.core.test.PerformanceDegradationTest;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestName;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import java.util.Collections;
+
+/**
+ * Parameterized instrumentation test that starts {@link NNPerformanceDegradationTestActivity}
+ * with the performance degradation settings of the current parameter set and expects the
+ * activity to report {@code CrashTestStatus.TestResult.SUCCESS}.
+ *
+ * Parameter sets come from {@code AcceleratorSpecificTestSupport.perAcceleratorTestConfig},
+ * seeded with a single base configuration of 1 extra thread and a 50% degradation threshold.
+ */
+@RunWith(Parameterized.class)
+public class NNPerformanceDegradationTest extends
+        ActivityInstrumentationTestCase2<NNPerformanceDegradationTestActivity> implements
+        AcceleratorSpecificTestSupport {
+    public static final String TAG = PerformanceDegradationTest.TAG;
+
+    // Benchmark timings forwarded to the activity through the launch intent.
+    private static final float WARM_UP_TIME_SECONDS = 5;
+    private static final float RUN_TIME_SECONDS = 20;
+
+    private final String mAcceleratorName;
+    private final int mThreadCount;
+    private final int mMaxPerformanceDegradationPercent;
+
+    @Rule
+    public TestName mTestName = new TestName();
+
+    /** Supplies the constructor arguments for every run of this test class. */
+    @Parameters(name = "Threshold {1} % and {0} extra compiling threads on accelerator {2}")
+    public static Iterable<Object[]> testConfiguration() {
+        final Object[] threadsAndThreshold = {/*thread count*/1, /*max degradation percent*/50};
+        return AcceleratorSpecificTestSupport.perAcceleratorTestConfig(
+                Collections.singletonList(threadsAndThreshold));
+    }
+
+    /**
+     * @param threadCount number of extra compiling threads
+     * @param maxPerformanceDegradationPercent highest accepted degradation, in percent
+     * @param acceleratorName accelerator the test targets
+     */
+    public NNPerformanceDegradationTest(int threadCount,
+            int maxPerformanceDegradationPercent, String acceleratorName) {
+        super(NNPerformanceDegradationTestActivity.class);
+        mThreadCount = threadCount;
+        mMaxPerformanceDegradationPercent = maxPerformanceDegradationPercent;
+        mAcceleratorName = acceleratorName;
+    }
+
+    /**
+     * Builds the intent used to launch the activity, filled in by
+     * {@code PerformanceDegradationTest.intentInitializer} with this class's warm-up and
+     * run times plus the given settings.
+     */
+    protected static Intent getTestMaxPerfDegradationOfModelWIthThreads(String testName,
+            String acceleratorName, int threadCount,
+            int maxPerformanceDegradationPercent) {
+        final Intent launchIntent = new Intent();
+        PerformanceDegradationTest
+                .intentInitializer(WARM_UP_TIME_SECONDS, RUN_TIME_SECONDS, acceleratorName,
+                        threadCount, maxPerformanceDegradationPercent, testName)
+                .addIntentParams(launchIntent);
+        return launchIntent;
+    }
+
+    /** Injects the instrumentation, waits for the device charge, then sets the launch intent. */
+    @Before
+    @Override
+    public void setUp() {
+        injectInstrumentation(InstrumentationRegistry.getInstrumentation());
+        BenchmarkTestBase.waitUntilCharged(getInstrumentation().getTargetContext(), 60);
+        final String currentTestName = mTestName.getMethodName();
+        final Intent launchIntent = getTestMaxPerfDegradationOfModelWIthThreads(
+                currentTestName, mAcceleratorName, mThreadCount,
+                mMaxPerformanceDegradationPercent);
+        setActivityIntent(launchIntent);
+    }
+
+    /** Passes only when the launched activity reports a successful test result. */
+    @Test
+    @UiThreadTest
+    public void shouldNotDegradePerformanceOverThreshold() {
+        assertEquals("Test didn't complete successfully", CrashTestStatus.TestResult.SUCCESS,
+                getActivity().testResult());
+    }
+}
diff --git a/src/com/android/nn/crashtest/app/NNPerformanceDegradationTestActivity.java b/src/com/android/nn/crashtest/app/NNPerformanceDegradationTestActivity.java
new file mode 100644
index 0000000..57c3b6b
--- /dev/null
+++ b/src/com/android/nn/crashtest/app/NNPerformanceDegradationTestActivity.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.nn.crashtest.app;
+
+import android.content.Intent;
+
+import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.crashtest.core.CrashTest;
+import com.android.nn.crashtest.core.CrashTestCoordinator;
+import com.android.nn.crashtest.core.test.PerformanceDegradationTest;
+
+/**
+ * Crash-test activity that runs {@link PerformanceDegradationTest}, taking the test name,
+ * timings and intent parameters from the launching intent.
+ */
+public class NNPerformanceDegradationTestActivity extends NNCrashTestActivity {
+    private static final String TAG = NNPerformanceDegradationTest.TAG;
+
+    @Override
+    protected String getTag() {
+        return TAG;
+    }
+
+    /** Reads the test name stored under {@code PerformanceDegradationTest.TEST_NAME}. */
+    @Override
+    protected String getTestName(Intent intent) {
+        return intent.getStringExtra(PerformanceDegradationTest.TEST_NAME);
+    }
+
+    /**
+     * Estimates the total test duration from the run and warm-up times carried by the intent
+     * (falling back to the {@link PerformanceDegradationTest} defaults).
+     *
+     * @param intent launch intent carrying the benchmark timings
+     * @return estimated duration in milliseconds for the whole models list
+     */
+    @Override
+    protected long getTestDurationMillis(Intent intent) {
+        final float testBenchmarkRuntimeSeconds = intent.getFloatExtra(
+                PerformanceDegradationTest.RUN_TIME_SECONDS,
+                PerformanceDegradationTest.DEFAULT_RUN_TIME_SECONDS);
+        final float testBenchmarkWarmupTimeSeconds = intent.getFloatExtra(
+                PerformanceDegradationTest.WARMUP_SECONDS,
+                PerformanceDegradationTest.DEFAULT_WARMUP_SECONDS);
+        // Two cycles of performance measurement are taken for every model, one single-threaded
+        // and one multi-threaded, with a pause in between. We assume every model in the list
+        // is available, which over-estimates rather than under-estimates the duration.
+        final float secondsPerMeasurement =
+                testBenchmarkRuntimeSeconds + testBenchmarkWarmupTimeSeconds;
+        // Scale to milliseconds BEFORE truncating to long: casting the seconds sum first
+        // would silently drop any fractional seconds from the estimate.
+        final long oneModelTestDuration = (long) (secondsPerMeasurement * 1000 * 3);
+        return oneModelTestDuration * TestModels.modelsList().size();
+    }
+
+    /** Copies the degradation-test parameters of the launching intent. */
+    @Override
+    protected CrashTestCoordinator.CrashTestIntentInitializer getIntentInitializer(Intent intent) {
+        return PerformanceDegradationTest.intentInitializer(intent);
+    }
+
+    @Override
+    protected Class<? extends CrashTest> getTestClass() {
+        return PerformanceDegradationTest.class;
+    }
+}
\ No newline at end of file
diff --git a/src/com/android/nn/crashtest/app/NNRandomGraphTest.java b/src/com/android/nn/crashtest/app/NNRandomGraphTest.java
index 86e3b0f..b477949 100644
--- a/src/com/android/nn/crashtest/app/NNRandomGraphTest.java
+++ b/src/com/android/nn/crashtest/app/NNRandomGraphTest.java
@@ -99,7 +99,7 @@
protected Optional<TestModels.TestModelEntry> findModelForLivenessTest()
throws NnApiDelegationFailure {
- return findTestModelRunningOnAccelerator(
+ return AcceleratorSpecificTestSupport.findTestModelRunningOnAccelerator(
getInstrumentation().getTargetContext(), mAcceleratorName);
}
diff --git a/src/com/android/nn/crashtest/core/CrashTest.java b/src/com/android/nn/crashtest/core/CrashTest.java
index 68b70ae..34a7e0f 100644
--- a/src/com/android/nn/crashtest/core/CrashTest.java
+++ b/src/com/android/nn/crashtest/core/CrashTest.java
@@ -35,4 +35,5 @@
/** Result value for a passed test: an empty {@link Optional}. */
default Optional<String> success() {
    return Optional.empty();
}
/** Result value for a failed test: an {@link Optional} holding the given reason. */
default Optional<String> failure(String reason) {
    return Optional.of(reason);
}
+ /** Tells whether a test result carries a value, i.e. whether it holds a failure reason. */
+ default boolean isFailure(Optional<String> result) {
+     return result.isPresent();
+ }
}
diff --git a/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java b/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
new file mode 100644
index 0000000..038d554
--- /dev/null
+++ b/src/com/android/nn/crashtest/core/test/PerformanceDegradationTest.java
@@ -0,0 +1,289 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.nn.crashtest.core.test;
+
+import android.annotation.SuppressLint;
+import android.content.Context;
+import android.content.Intent;
+import android.util.Log;
+
+import com.android.nn.benchmark.core.BenchmarkException;
+import com.android.nn.benchmark.core.BenchmarkResult;
+import com.android.nn.benchmark.core.Processor;
+import com.android.nn.benchmark.core.TestModels;
+import com.android.nn.crashtest.app.AcceleratorSpecificTestSupport;
+import com.android.nn.crashtest.core.CrashTest;
+import com.android.nn.crashtest.core.CrashTestCoordinator;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.stream.Stream;
+
+/**
+ * Crash test that measures how much the inference time of a model grows while other threads
+ * keep compiling a model on the same accelerator, and fails when the growth exceeds a
+ * configured percentage.
+ *
+ * The inference model is the first model supported by the accelerator; the measurement is
+ * repeated using each supported model in turn as the one being compiled.
+ */
+public class PerformanceDegradationTest implements CrashTest {
+    public static final String TAG = "NN_PERF_DEG";
+
+    // Processor requires a callback; this test has no use for its notifications.
+    private static final Processor.Callback mNoOpCallback = new Processor.Callback() {
+        @Override
+        public void onBenchmarkFinish(boolean ok) {
+        }
+
+        @Override
+        public void onStatusUpdate(int testNumber, int numTests, String modelName) {
+        }
+    };
+
+    // Intent extra keys and the defaults applied when an extra is missing.
+    public static final String WARMUP_SECONDS = "warmup_seconds";
+    public static final String RUN_TIME_SECONDS = "run_time_seconds";
+    public static final String ACCELERATOR_NAME = "accelerator_name";
+    public static final float DEFAULT_WARMUP_SECONDS = 3.0f;
+    public static final float DEFAULT_RUN_TIME_SECONDS = 10.0f;
+    public static final String THREAD_COUNT = "thread_count";
+    public static final int DEFAULT_THREAD_COUNT = 5;
+    public static final String MAX_PERFORMANCE_DEGRADATION = "max_performance_degradation";
+    public static final int DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE = 100;
+    public static final String TEST_NAME = "test_name";
+    // Pause between the baseline measurement and the measurement under load.
+    private static final long INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS = 500;
+
+    /**
+     * Creates an initializer that stores all the given test settings as intent extras under
+     * the keys declared in this class.
+     */
+    static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
+            float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount,
+            int maxPerformanceDegradationPercent, String testName) {
+        return intent -> {
+            intent.putExtra(WARMUP_SECONDS, warmupTimeSeconds);
+            intent.putExtra(RUN_TIME_SECONDS, runTimeSeconds);
+            intent.putExtra(ACCELERATOR_NAME, acceleratorName);
+            intent.putExtra(THREAD_COUNT, threadCount);
+            intent.putExtra(MAX_PERFORMANCE_DEGRADATION, maxPerformanceDegradationPercent);
+            intent.putExtra(TEST_NAME, testName);
+        };
+    }
+
+    /**
+     * Creates an initializer that copies the test settings found in {@code copyFrom},
+     * substituting the defaults of this class for missing numeric extras.
+     */
+    static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
+            Intent copyFrom) {
+        return intentInitializer(
+                copyFrom.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS),
+                copyFrom.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS),
+                copyFrom.getStringExtra(ACCELERATOR_NAME),
+                copyFrom.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT),
+                copyFrom.getIntExtra(MAX_PERFORMANCE_DEGRADATION,
+                        DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE),
+                copyFrom.getStringExtra(TEST_NAME));
+    }
+
+    private Context mContext;
+    private float mWarmupTimeSeconds;
+    private float mRunTimeSeconds;
+    private String mAcceleratorName;
+    private int mThreadCount;
+    private int mMaxPerformanceDegradationPercent;
+    private String mTestName;
+
+    /** Reads the test settings from {@code configParams}; the progress listener is unused. */
+    @Override
+    public void init(Context context, Intent configParams,
+            Optional<ProgressListener> progressListener) {
+        mContext = context;
+
+        mWarmupTimeSeconds = configParams.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS);
+        mRunTimeSeconds = configParams.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS);
+        mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME);
+        mThreadCount = configParams.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT);
+        mMaxPerformanceDegradationPercent = configParams.getIntExtra(MAX_PERFORMANCE_DEGRADATION,
+                DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE);
+        mTestName = configParams.getStringExtra(TEST_NAME);
+    }
+
+    /**
+     * Runs the degradation check once per model supported by the accelerator, always running
+     * inference on the first supported model.
+     *
+     * @return {@code success()} when every check passes, otherwise the first failure
+     */
+    @SuppressLint("DefaultLocale")
+    @Override
+    public Optional<String> call() throws Exception {
+        List<TestModels.TestModelEntry> modelsForAccelerator =
+                AcceleratorSpecificTestSupport.findAllTestModelsRunningOnAccelerator(mContext,
+                        mAcceleratorName);
+
+        if (modelsForAccelerator.isEmpty()) {
+            return failure("Cannot find any model to use for testing");
+        }
+
+        Log.i(TAG, String.format("Checking performance degradation using %d models",
+                modelsForAccelerator.size()));
+
+        TestModels.TestModelEntry modelForInference = modelsForAccelerator.get(0);
+        // The performance degradation is strongly dependent on the model used to compile
+        // so we check all the available ones.
+        for (TestModels.TestModelEntry modelForCompilation : modelsForAccelerator) {
+            Optional<String> currTestResult = testDegradationForModels(modelForInference,
+                    modelForCompilation);
+            if (isFailure(currTestResult)) {
+                return currTestResult;
+            }
+        }
+
+        return success();
+    }
+
+    /**
+     * Measures the mean inference time of {@code inferenceModelEntry} alone and then again
+     * while {@code mThreadCount} threads repeatedly compile {@code compilationModelEntry},
+     * and compares the relative slowdown with the configured threshold.
+     *
+     * @param inferenceModelEntry model whose inference time is measured
+     * @param compilationModelEntry model compiled by the concurrent threads
+     * @return {@code success()} when the slowdown is within the threshold, otherwise a
+     *         failure describing the benchmark error or the measured degradation
+     * @throws Exception if a benchmark run fails or the calling thread is interrupted
+     */
+    @SuppressLint("DefaultLocale")
+    public Optional<String> testDegradationForModels(
+            TestModels.TestModelEntry inferenceModelEntry,
+            TestModels.TestModelEntry compilationModelEntry) throws Exception {
+        Log.i(TAG, String.format(
+                "Testing degradation in inference of model %s when running %d threads compliing "
+                        + "model %s",
+                inferenceModelEntry.mModelName, mThreadCount, compilationModelEntry.mModelName));
+
+        Log.d(TAG, String.format("%s: Calculating baseline", mTestName));
+        // first let's measure a baseline performance
+        final BenchmarkResult baseline = modelPerformanceCollector(inferenceModelEntry,
+                /*start=*/ null).call();
+        if (baseline.hasBenchmarkError()) {
+            return failure(String.format("%s: Baseline has benchmark error '%s'",
+                    mTestName, baseline.getBenchmarkError()));
+        }
+        Log.d(TAG, String.format("%s: Baseline mean time is %f seconds", mTestName,
+                baseline.getMeanTimeSec()));
+
+        Log.d(TAG, String.format("%s: Sleeping for %d millis", mTestName,
+                INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS));
+        Thread.sleep(INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS);
+
+        Log.d(TAG, String.format("%s: Calculating performance with %d threads", mTestName,
+                mThreadCount));
+        // The latch makes the inference thread and all compiler threads start together.
+        final int totalThreadCount = mThreadCount + 1;
+        final CountDownLatch start = new CountDownLatch(totalThreadCount);
+        ModelCompiler[] compilers = Stream.generate(
+                () -> new ModelCompiler(start, mContext, mAcceleratorName,
+                        compilationModelEntry)).limit(
+                mThreadCount).toArray(
+                ModelCompiler[]::new);
+
+        Callable<BenchmarkResult> performanceWithOtherCompilingThreadCollector =
+                modelPerformanceCollector(inferenceModelEntry, start);
+
+        final BenchmarkResult benchmarkWithOtherCompilingThread;
+        ExecutorService testExecutor = Executors.newFixedThreadPool(totalThreadCount);
+        try {
+            Future<?>[] compilerFutures = Arrays.stream(compilers).map(
+                    testExecutor::submit).toArray(Future[]::new);
+            try {
+                benchmarkWithOtherCompilingThread = testExecutor.submit(
+                        performanceWithOtherCompilingThreadCollector).get();
+            } finally {
+                // Stop the compilers even when the inference run fails, so they do not
+                // keep compiling in the background.
+                Arrays.stream(compilers).forEach(ModelCompiler::stop);
+            }
+            awaitCompilers(compilerFutures);
+        } finally {
+            // Release the pool threads; this also interrupts compilers still blocked on the
+            // start latch or sleeping when the inference run failed.
+            testExecutor.shutdownNow();
+        }
+
+        if (benchmarkWithOtherCompilingThread.hasBenchmarkError()) {
+            return failure(
+                    String.format(
+                            "%s: Test with parallel compiling thrads has benchmark error '%s'",
+                            mTestName, benchmarkWithOtherCompilingThread.getBenchmarkError()));
+        }
+
+        Log.d(TAG, String.format("%s: Multithreaded mean time is %f seconds",
+                mTestName, benchmarkWithOtherCompilingThread.getMeanTimeSec()));
+
+        // Relative slowdown versus the baseline, as a whole percentage.
+        int performanceDegradation = (int) (((benchmarkWithOtherCompilingThread.getMeanTimeSec()
+                / baseline.getMeanTimeSec()) - 1.0) * 100);
+
+        Log.i(TAG, String.format(
+                "%s: Performance degradation for accelerator %s, with %d threads is %d%%. "
+                        + "Threshold "
+                        + "is %d%%",
+                mTestName, mAcceleratorName, mThreadCount, performanceDegradation,
+                mMaxPerformanceDegradationPercent));
+
+        if (performanceDegradation > mMaxPerformanceDegradationPercent) {
+            return failure(String.format("Performance degradation is %d%%. Max acceptable is %d%%",
+                    performanceDegradation, mMaxPerformanceDegradationPercent));
+        }
+
+        return success();
+    }
+
+    /**
+     * Waits for each compiler task to finish, logging (not propagating) task failures.
+     * If the waiting thread is interrupted, its interrupt status is restored and waiting stops.
+     */
+    private static void awaitCompilers(Future<?>[] compilerFutures) {
+        for (Future<?> compilerFuture : compilerFutures) {
+            try {
+                compilerFuture.get();
+            } catch (InterruptedException e) {
+                Thread.currentThread().interrupt();
+                Log.e(TAG, "Error waiting for compiler process completion", e);
+                return;
+            } catch (ExecutionException e) {
+                Log.e(TAG, "Error waiting for compiler process completion", e);
+            }
+        }
+    }
+
+    /**
+     * Builds a task that benchmarks inference of the given model on the configured
+     * accelerator.
+     *
+     * @param inferenceModelEntry model to benchmark
+     * @param start latch to count down and wait on before benchmarking, or {@code null} to
+     *              start immediately
+     */
+    private Callable<BenchmarkResult> modelPerformanceCollector(
+            final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start) {
+        return () -> {
+            Processor benchmarkProcessor = new Processor(mContext, mNoOpCallback, new int[0]);
+            benchmarkProcessor.setUseNNApi(true);
+            benchmarkProcessor.setNnApiAcceleratorName(mAcceleratorName);
+            if (start != null) {
+                start.countDown();
+                start.await();
+            }
+            final BenchmarkResult result =
+                    benchmarkProcessor.getInstrumentationResult(
+                            inferenceModelEntry, mWarmupTimeSeconds, mRunTimeSeconds);
+
+            return result;
+        };
+    }
+
+    /**
+     * Task that compiles one model over and over, with a short sleep between compilations,
+     * until {@link #stop()} is called or the thread is interrupted.
+     */
+    private static class ModelCompiler implements Callable<Void> {
+        private static final long SLEEP_BETWEEN_COMPILATION_INTERVAL_MS = 20;
+        private final CountDownLatch mStart;
+        private final Processor mProcessor;
+        private final TestModels.TestModelEntry mTestModelEntry;
+        // Written by the coordinating thread through stop(), read by the compiling thread.
+        private volatile boolean mRun;
+
+        ModelCompiler(final CountDownLatch start, final Context context,
+                final String acceleratorName, TestModels.TestModelEntry testModelEntry) {
+            mStart = start;
+            mTestModelEntry = testModelEntry;
+            mProcessor = new Processor(context, mNoOpCallback, new int[0]);
+            mProcessor.setUseNNApi(true);
+            mProcessor.setNnApiAcceleratorName(acceleratorName);
+            mProcessor.setRunModelCompilationOnly(true);
+            mRun = true;
+        }
+
+        @Override
+        public Void call() throws IOException, BenchmarkException {
+            if (mStart != null) {
+                try {
+                    mStart.countDown();
+                    mStart.await();
+                } catch (InterruptedException e) {
+                    // Restore the interrupt status so the owner of this thread can see it.
+                    Thread.currentThread().interrupt();
+                    Log.i(TAG, "Interrupted, stopping processing");
+                    return null;
+                }
+            }
+            while (mRun) {
+                mProcessor.getInstrumentationResult(mTestModelEntry, 0, 0);
+                try {
+                    Thread.sleep(SLEEP_BETWEEN_COMPILATION_INTERVAL_MS);
+                } catch (InterruptedException e) {
+                    // Restore the interrupt status so the owner of this thread can see it.
+                    Thread.currentThread().interrupt();
+                    return null;
+                }
+            }
+            return null;
+        }
+
+        /** Asks the compile loop to exit after the current iteration. */
+        public void stop() {
+            mRun = false;
+        }
+    }
+}