Reduce number of results collected in long tests am: b9965e9588 am: 529ef97735
Original change: https://android-review.googlesource.com/c/platform/test/mlts/benchmark/+/1418596
Change-Id: Ica827fbadbbb8e7986f77b92f14b1cbd3aac82ae
diff --git a/jni/run_tflite.cpp b/jni/run_tflite.cpp
index 481abee..8362af0 100644
--- a/jni/run_tflite.cpp
+++ b/jni/run_tflite.cpp
@@ -277,6 +277,7 @@
const int inputOutputSequenceIndex = seqInferenceIndex % inOutData.size();
const InferenceInOutSequence& seq = inOutData[inputOutputSequenceIndex];
+ const bool sampleResults = (flags & FLAG_SAMPLE_BENCHMARK_RESULTS) != 0;
for (int i = 0; i < seq.size(); ++i) {
const InferenceInOut& data = seq[i];
@@ -333,7 +334,10 @@
saveInferenceOutput(&result, j);
}
}
- results->push_back(result);
+
+ if (!sampleResults || (seqInferenceIndex % INFERENCE_OUT_SAMPLE_RATE) == 0) {
+ results->push_back(result);
+ }
inferenceTotal += inferenceTime;
}
diff --git a/jni/run_tflite.h b/jni/run_tflite.h
index 978348e..9e0f086 100644
--- a/jni/run_tflite.h
+++ b/jni/run_tflite.h
@@ -69,6 +69,10 @@
const int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
/** Do not expect golden output for inference inputs. */
const int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;
+/** Collect only 1 benchmark result every INFERENCE_OUT_SAMPLE_RATE inferences. **/
+const int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;
+
+const int INFERENCE_OUT_SAMPLE_RATE = 10;
enum class CompilationBenchmarkType {
// Benchmark without cache
diff --git a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
index 199d3d7..b2e9e3a 100644
--- a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
+++ b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
@@ -176,15 +176,22 @@
private final float mMaxWarmupTimeSeconds;
private final float mMaxRunTimeSeconds;
private final CountDownLatch actionComplete;
+ private final boolean mSampleResults;
BenchmarkResult mResult;
Throwable mException;
public TestAction(TestModelEntry testName, float maxWarmupTimeSeconds,
float maxRunTimeSeconds) {
+ this(testName, maxWarmupTimeSeconds, maxRunTimeSeconds, false);
+ }
+
+ public TestAction(TestModelEntry testName, float maxWarmupTimeSeconds,
+ float maxRunTimeSeconds, boolean sampleResults) {
mTestModel = testName;
mMaxWarmupTimeSeconds = maxWarmupTimeSeconds;
mMaxRunTimeSeconds = maxRunTimeSeconds;
+ mSampleResults = sampleResults;
actionComplete = new CountDownLatch(1);
}
@@ -195,7 +202,7 @@
mMaxRunTimeSeconds));
try {
mResult = mActivity.runSynchronously(
- mTestModel, mMaxWarmupTimeSeconds, mMaxRunTimeSeconds);
+ mTestModel, mMaxWarmupTimeSeconds, mMaxRunTimeSeconds, mSampleResults);
} catch (BenchmarkException | IOException e) {
mException = e;
Log.e(NNBenchmark.TAG,
@@ -242,7 +249,9 @@
final String traceName = "[NN_LA_PO]" + testName;
try {
Trace.beginSection(traceName);
+ Log.i(NNBenchmark.TAG, "Starting test " + testName);
runOnUiThread(ta);
+ Log.i(NNBenchmark.TAG, "Test " + testName + " completed");
} finally {
Trace.endSection();
}
diff --git a/src/com/android/nn/benchmark/app/NNBenchmark.java b/src/com/android/nn/benchmark/app/NNBenchmark.java
index 6253167..a87760d 100644
--- a/src/com/android/nn/benchmark/app/NNBenchmark.java
+++ b/src/com/android/nn/benchmark/app/NNBenchmark.java
@@ -138,7 +138,7 @@
}
public BenchmarkResult runSynchronously(TestModelEntry testModel,
- float warmupTimeSeconds, float runTimeSeconds) throws IOException, BenchmarkException {
- return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds);
+ float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults) throws IOException, BenchmarkException {
+ return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds, sampleResults);
}
}
diff --git a/src/com/android/nn/benchmark/app/NNInferenceStressTest.java b/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
index f8ee7a5..270c4fe 100644
--- a/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
+++ b/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
@@ -17,17 +17,20 @@
package com.android.nn.benchmark.app;
import android.test.suitebuilder.annotation.LargeTest;
+
import com.android.nn.benchmark.core.TestModels;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.stream.Collectors;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
/**
* Tests that ensure stability of NNAPI by running inference for a
* prolonged period of time.
@@ -37,7 +40,8 @@
private static final String TAG = NNInferenceStressTest.class.getSimpleName();
private static final float WARMUP_SECONDS = 0; // No warmup.
- private static final float RUNTIME_SECONDS = 60 * 60; // 1 hour.
+ private static final float RUNTIME_SECONDS = Duration.ofHours(1).getSeconds();
+ private static final long LONG_STRESS_TEST_DURATION_SECONDS = Duration.ofHours(4).getSeconds();
public NNInferenceStressTest(TestModels.TestModelEntry model) {
super(model);
@@ -58,7 +62,11 @@
waitUntilCharged();
setUseNNApi(true);
setCompleteInputSet(false);
- TestAction ta = new TestAction(mModel, WARMUP_SECONDS, RUNTIME_SECONDS);
+ // Sample the results for very long tests to keep the collected results
+ // from saturating available memory.
+ final boolean shouldSampleResults = RUNTIME_SECONDS >= LONG_STRESS_TEST_DURATION_SECONDS;
+ TestAction ta = new TestAction(mModel, WARMUP_SECONDS, RUNTIME_SECONDS,
+ shouldSampleResults);
runTest(ta, mModel.getTestName());
}
}
diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java
index 455ab5e..39c455d 100644
--- a/src/com/android/nn/benchmark/core/NNTestBase.java
+++ b/src/com/android/nn/benchmark/core/NNTestBase.java
@@ -105,6 +105,9 @@
public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;
+ /** Collect only 1 benchmark result every 10 inferences (keep in sync with INFERENCE_OUT_SAMPLE_RATE in jni/run_tflite.h). **/
+ public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;
+
protected Context mContext;
protected TextView mText;
private final String mModelName;
@@ -123,6 +126,7 @@
private boolean mMmapModel = false;
// Path where the current model has been stored for execution
private String mTemporaryModelFilePath;
+ private boolean mSampleResults;
public NNTestBase(String modelName, String modelFile, int[] inputShape,
InferenceInOutSequence.FromAssets[] inputOutputAssets,
@@ -145,6 +149,7 @@
mModelHandle = 0;
mEvaluatorConfig = evaluator;
mMinSdkVersion = minSdkVersion;
+ mSampleResults = false;
}
public void useNNApi() {
@@ -267,6 +272,10 @@
if (mEvaluator == null) {
flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
}
+ // For very long tests we will collect only a sample of the results
+ if (mSampleResults) {
+ flags = flags | FLAG_SAMPLE_BENCHMARK_RESULTS;
+ }
return flags;
}
@@ -421,4 +430,8 @@
public void close() {
destroy();
}
+
+ public void setSampleResult(boolean sampleResults) {
+ this.mSampleResults = sampleResults;
+ }
}
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 200bf6b..1cf605c 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -128,14 +128,23 @@
mCompilationBenchmarkMaxIterations = maxIterations;
}
- // Method to retrieve benchmark results for instrumentation tests.
- // Returns null if the processor is configured to run compilation only
public BenchmarkResult getInstrumentationResult(
TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
throws IOException, BenchmarkException {
+ return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false);
+ }
+
+ // Method to retrieve benchmark results for instrumentation tests.
+ // Returns null if the processor is configured to run compilation only
+ public BenchmarkResult getInstrumentationResult(
+ TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds,
+ boolean sampleResults)
+ throws IOException, BenchmarkException {
mTest = changeTest(mTest, t);
+ mTest.setSampleResult(sampleResults);
try {
- BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(warmupTimeSeconds,
+ BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(
+ warmupTimeSeconds,
runTimeSeconds);
return result;
} finally {