Always use all available models in NN*Test am: 1927d05e51 am: 45f9bbfb1f am: a1cc9ea636
am: 887c50ef55
Change-Id: Ic5341d7d2c2fed86d4345c2b2c4745825bb2b3ed
diff --git a/build_and_run_benchmark.sh b/build_and_run_benchmark.sh
index ab40c49..90331ad 100755
--- a/build_and_run_benchmark.sh
+++ b/build_and_run_benchmark.sh
@@ -4,7 +4,7 @@
#
# Output is logged to a temporary folder and summarized in txt and JSON formats.
-MODE="${1:=scoring}"
+MODE="${1:-scoring}"
case "$MODE" in
scoring)
diff --git a/src/com/android/nn/benchmark/app/NNInferenceStressTest.java b/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
index 9b4fdf4..0773418 100644
--- a/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
+++ b/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
@@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.stream.Collectors;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -35,7 +36,6 @@
public class NNInferenceStressTest extends BenchmarkTestBase {
private static final String TAG = NNInferenceStressTest.class.getSimpleName();
- private static final String[] MODEL_NAMES = NNScoringTest.MODEL_NAMES;
private static final float WARMUP_SECONDS = 0; // No warmup.
private static final float RUNTIME_SECONDS = 60 * 60; // 1 hour.
@@ -45,22 +45,21 @@
@Parameters(name = "{0}")
public static List<TestModels.TestModelEntry> modelsList() {
- List<TestModels.TestModelEntry> models = new ArrayList<>();
- for (String modelName : MODEL_NAMES) {
- TestModels.TestModelEntry model = TestModels.getModelByName(modelName);
- models.add(
- new TestModels.TestModelEntry(
- model.mModelName,
- model.mBaselineSec,
- model.mInputShape,
- model.mInOutAssets,
- model.mInOutDatasets,
- model.mTestName,
- model.mModelFile,
- null, // Disable evaluation.
- model.mMinSdkVersion));
- }
- return Collections.unmodifiableList(models);
+ return TestModels.modelsList().stream()
+ .map(model ->
+ new TestModels.TestModelEntry(
+ model.mModelName,
+ model.mBaselineSec,
+ model.mInputShape,
+ model.mInOutAssets,
+ model.mInOutDatasets,
+ model.mTestName,
+ model.mModelFile,
+ null, // Disable evaluation.
+ model.mMinSdkVersion))
+ .collect(Collectors.collectingAndThen(
+ Collectors.toList(),
+ Collections::unmodifiableList));
}
@Test
diff --git a/src/com/android/nn/benchmark/app/NNModelLoadingStressTest.java b/src/com/android/nn/benchmark/app/NNModelLoadingStressTest.java
index faf507e..03f5ac2 100644
--- a/src/com/android/nn/benchmark/app/NNModelLoadingStressTest.java
+++ b/src/com/android/nn/benchmark/app/NNModelLoadingStressTest.java
@@ -23,8 +23,9 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.List;
import java.util.concurrent.TimeUnit;
+import java.util.List;
+import java.util.stream.Collectors;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Stopwatch;
@@ -37,7 +38,6 @@
public class NNModelLoadingStressTest extends BenchmarkTestBase {
private static final String TAG = NNModelLoadingStressTest.class.getSimpleName();
- private static final String[] MODEL_NAMES = NNScoringTest.MODEL_NAMES;
private static final float WARMUP_SECONDS = 0; // No warmup.
private static final float INFERENCE_SECONDS = 0; // No inference.
private static final float RUNTIME_SECONDS = 30 * 60;
@@ -50,22 +50,21 @@
@Parameters(name = "{0}")
public static List<TestModels.TestModelEntry> modelsList() {
- List<TestModels.TestModelEntry> models = new ArrayList<>();
- for (String modelName : MODEL_NAMES) {
- TestModels.TestModelEntry model = TestModels.getModelByName(modelName);
- models.add(
- new TestModels.TestModelEntry(
- model.mModelName,
- model.mBaselineSec,
- model.mInputShape,
- new InferenceInOutSequence.FromAssets[0], // No inputs for inference.
- null,
- model.mTestName,
- model.mModelFile,
- null, // Disable evaluation.
- model.mMinSdkVersion));
- }
- return Collections.unmodifiableList(models);
+ return TestModels.modelsList().stream()
+ .map(model ->
+ new TestModels.TestModelEntry(
+ model.mModelName,
+ model.mBaselineSec,
+ model.mInputShape,
+ new InferenceInOutSequence.FromAssets[0], // No inputs for inference.
+ null,
+ model.mTestName,
+ model.mModelFile,
+ null, // Disable evaluation.
+ model.mMinSdkVersion))
+ .collect(Collectors.collectingAndThen(
+ Collectors.toList(),
+ Collections::unmodifiableList));
}
@Test
diff --git a/src/com/android/nn/benchmark/app/NNScoringTest.java b/src/com/android/nn/benchmark/app/NNScoringTest.java
index d987663..f3cb5a7 100644
--- a/src/com/android/nn/benchmark/app/NNScoringTest.java
+++ b/src/com/android/nn/benchmark/app/NNScoringTest.java
@@ -61,34 +61,6 @@
super.prepareTest();
}
- // Shared with NNStressTest.
- static final String[] MODEL_NAMES = new String[]{
- "tts_float",
- "asr_float",
- "mobilenet_v1_1.0_224_quant_topk_aosp",
- "mobilenet_v1_1.0_224_topk_aosp",
- "mobilenet_v1_0.75_192_quant_topk_aosp",
- "mobilenet_v1_0.75_192_topk_aosp",
- "mobilenet_v1_0.5_160_quant_topk_aosp",
- "mobilenet_v1_0.5_160_topk_aosp",
- "mobilenet_v1_0.25_128_quant_topk_aosp",
- "mobilenet_v1_0.25_128_topk_aosp",
- "mobilenet_v2_0.35_128_topk_aosp",
- "mobilenet_v2_0.5_160_topk_aosp",
- "mobilenet_v2_0.75_192_topk_aosp",
- "mobilenet_v2_1.0_224_topk_aosp",
- "mobilenet_v2_1.0_224_quant_topk_aosp",
- };
-
- @Parameters(name = "{0}")
- public static List<TestModels.TestModelEntry> modelsList() {
- List<TestModels.TestModelEntry> models = new ArrayList<>();
- for (String modelName : MODEL_NAMES) {
- models.add(TestModels.getModelByName(modelName));
- }
- return Collections.unmodifiableList(models);
- }
-
@Test
@LargeTest
public void testTFLite() throws IOException {