Improve script for building and running dump intermediate tensors. am: 0288bc3de7
Original change: https://android-review.googlesource.com/c/platform/test/mlts/benchmark/+/1393199
Change-Id: Iebb5234cdca4a9c578553b1f683f8533acf886c6
diff --git a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
index 18241f6..199d3d7 100644
--- a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
+++ b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
@@ -173,28 +173,29 @@
class TestAction implements Joinable {
private final TestModelEntry mTestModel;
- private final float mWarmupTimeSeconds;
- private final float mRunTimeSeconds;
+ private final float mMaxWarmupTimeSeconds;
+ private final float mMaxRunTimeSeconds;
private final CountDownLatch actionComplete;
BenchmarkResult mResult;
Throwable mException;
- public TestAction(TestModelEntry testName, float warmupTimeSeconds, float runTimeSeconds) {
+ public TestAction(TestModelEntry testName, float maxWarmupTimeSeconds,
+ float maxRunTimeSeconds) {
mTestModel = testName;
- mWarmupTimeSeconds = warmupTimeSeconds;
- mRunTimeSeconds = runTimeSeconds;
+ mMaxWarmupTimeSeconds = maxWarmupTimeSeconds;
+ mMaxRunTimeSeconds = maxRunTimeSeconds;
actionComplete = new CountDownLatch(1);
}
public void run() {
Log.v(NNBenchmark.TAG, String.format(
- "Starting benchmark for test '%s' running for at least %f seconds",
+ "Starting benchmark for test '%s' running for max %f seconds",
mTestModel.mTestName,
- mRunTimeSeconds));
+ mMaxRunTimeSeconds));
try {
mResult = mActivity.runSynchronously(
- mTestModel, mWarmupTimeSeconds, mRunTimeSeconds);
+ mTestModel, mMaxWarmupTimeSeconds, mMaxRunTimeSeconds);
} catch (BenchmarkException | IOException e) {
mException = e;
Log.e(NNBenchmark.TAG,
diff --git a/src/com/android/nn/benchmark/app/NNScoringTest.java b/src/com/android/nn/benchmark/app/NNScoringTest.java
index 33e1402..a909450 100644
--- a/src/com/android/nn/benchmark/app/NNScoringTest.java
+++ b/src/com/android/nn/benchmark/app/NNScoringTest.java
@@ -70,7 +70,7 @@
setCompleteInputSet(useCompleteInputSet);
enableCompilationCachingBenchmarks();
TestAction ta = new TestAction(mModel, WARMUP_REPEATABLE_SECONDS,
- RUNTIME_REPEATABLE_SECONDS);
+ useCompleteInputSet ? COMPLETE_SET_TIMEOUT_SECOND : RUNTIME_REPEATABLE_SECONDS);
runTest(ta, mModel.getTestName());
try (CSVWriter writer = new CSVWriter(getLocalCSVFile())) {
diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java
index 745222b..455ab5e 100644
--- a/src/com/android/nn/benchmark/core/NNTestBase.java
+++ b/src/com/android/nn/benchmark/core/NNTestBase.java
@@ -320,9 +320,11 @@
flags);
if (result.second.size() != extpectedResults) {
// We reached a timeout or failed to evaluate whole set for other reason, abort.
- final String errorMsg = "Failed to evaluate complete input set, expected: "
- + extpectedResults +
- ", received: " + result.second.size();
+ @SuppressLint("DefaultLocale")
+ final String errorMsg = String.format(
+ "Failed to evaluate complete input set, in %d seconds expected: %d, received:"
+ + " %d",
+ timeoutSec, extpectedResults, result.second.size());
Log.w(TAG, errorMsg);
throw new IllegalStateException(errorMsg);
}
@@ -416,7 +418,7 @@
}
@Override
- public void close() {
+ public void close() {
destroy();
}
}
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 778e5d0..200bf6b 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -185,18 +185,18 @@
return tb;
}
- // Run one loop of kernels for at least the specified minimum time.
+ // Run one loop of kernels for at most the specified maximum time.
// The function returns the average time in ms for the test run
- private BenchmarkResult runBenchmarkLoop(float minTime, boolean completeInputSet)
+ private BenchmarkResult runBenchmarkLoop(float maxTime, boolean completeInputSet)
throws IOException {
try {
// Run the kernel
Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
- if (minTime > 0.f) {
+ if (maxTime > 0.f) {
if (completeInputSet) {
- results = mTest.runBenchmarkCompleteInputSet(1, minTime);
+ results = mTest.runBenchmarkCompleteInputSet(1, maxTime);
} else {
- results = mTest.runBenchmark(minTime);
+ results = mTest.runBenchmark(maxTime);
}
} else {
results = mTest.runInferenceOnce();