Reduce number of results collected in long tests am: b9965e9588 am: 529ef97735

Original change: https://android-review.googlesource.com/c/platform/test/mlts/benchmark/+/1418596

Change-Id: Ica827fbadbbb8e7986f77b92f14b1cbd3aac82ae
diff --git a/jni/run_tflite.cpp b/jni/run_tflite.cpp
index 481abee..8362af0 100644
--- a/jni/run_tflite.cpp
+++ b/jni/run_tflite.cpp
@@ -277,6 +277,7 @@
 
     const int inputOutputSequenceIndex = seqInferenceIndex % inOutData.size();
     const InferenceInOutSequence& seq = inOutData[inputOutputSequenceIndex];
+    const bool sampleResults = (flags & FLAG_SAMPLE_BENCHMARK_RESULTS) != 0;
     for (int i = 0; i < seq.size(); ++i) {
       const InferenceInOut& data = seq[i];
 
@@ -333,7 +334,10 @@
           saveInferenceOutput(&result, j);
         }
       }
-      results->push_back(result);
+
+      if (!sampleResults || (seqInferenceIndex % INFERENCE_OUT_SAMPLE_RATE) == 0) {
+        results->push_back(result);
+      }
       inferenceTotal += inferenceTime;
     }
 
diff --git a/jni/run_tflite.h b/jni/run_tflite.h
index 978348e..9e0f086 100644
--- a/jni/run_tflite.h
+++ b/jni/run_tflite.h
@@ -69,6 +69,10 @@
 const int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
 /** Do not expect golden output for inference inputs. */
 const int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;
+/** Keep benchmark results only for one in every INFERENCE_OUT_SAMPLE_RATE sequence inferences. */
+const int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;
+
+const int INFERENCE_OUT_SAMPLE_RATE = 10;
 
 enum class CompilationBenchmarkType {
   // Benchmark without cache
diff --git a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
index 199d3d7..b2e9e3a 100644
--- a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
+++ b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java
@@ -176,15 +176,22 @@
         private final float mMaxWarmupTimeSeconds;
         private final float mMaxRunTimeSeconds;
         private final CountDownLatch actionComplete;
+        private final boolean mSampleResults;
 
         BenchmarkResult mResult;
         Throwable mException;
 
         public TestAction(TestModelEntry testName, float maxWarmupTimeSeconds,
                 float maxRunTimeSeconds) {
+            this(testName, maxWarmupTimeSeconds, maxRunTimeSeconds, false);
+        }
+
+        public TestAction(TestModelEntry testName, float maxWarmupTimeSeconds,
+                float maxRunTimeSeconds, boolean sampleResults) {
             mTestModel = testName;
             mMaxWarmupTimeSeconds = maxWarmupTimeSeconds;
             mMaxRunTimeSeconds = maxRunTimeSeconds;
+            mSampleResults = sampleResults;
             actionComplete = new CountDownLatch(1);
         }
 
@@ -195,7 +202,7 @@
                     mMaxRunTimeSeconds));
             try {
                 mResult = mActivity.runSynchronously(
-                        mTestModel, mMaxWarmupTimeSeconds, mMaxRunTimeSeconds);
+                        mTestModel, mMaxWarmupTimeSeconds, mMaxRunTimeSeconds, mSampleResults);
             } catch (BenchmarkException | IOException e) {
                 mException = e;
                 Log.e(NNBenchmark.TAG,
@@ -242,7 +249,9 @@
         final String traceName = "[NN_LA_PO]" + testName;
         try {
             Trace.beginSection(traceName);
+            Log.i(NNBenchmark.TAG, "Starting test " + testName);
             runOnUiThread(ta);
+            Log.i(NNBenchmark.TAG, "Test " + testName + " completed");
         } finally {
             Trace.endSection();
         }
diff --git a/src/com/android/nn/benchmark/app/NNBenchmark.java b/src/com/android/nn/benchmark/app/NNBenchmark.java
index 6253167..a87760d 100644
--- a/src/com/android/nn/benchmark/app/NNBenchmark.java
+++ b/src/com/android/nn/benchmark/app/NNBenchmark.java
@@ -138,7 +138,7 @@
     }
 
     public BenchmarkResult runSynchronously(TestModelEntry testModel,
-        float warmupTimeSeconds, float runTimeSeconds) throws IOException, BenchmarkException {
-        return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds);
+        float warmupTimeSeconds, float runTimeSeconds, boolean sampleResults) throws IOException, BenchmarkException {
+        return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds, sampleResults);
     }
 }
diff --git a/src/com/android/nn/benchmark/app/NNInferenceStressTest.java b/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
index f8ee7a5..270c4fe 100644
--- a/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
+++ b/src/com/android/nn/benchmark/app/NNInferenceStressTest.java
@@ -17,17 +17,20 @@
 package com.android.nn.benchmark.app;
 
 import android.test.suitebuilder.annotation.LargeTest;
+
 import com.android.nn.benchmark.core.TestModels;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.stream.Collectors;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameters;
 
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
 /**
  * Tests that ensure stability of NNAPI by running inference for a
  * prolonged period of time.
@@ -37,7 +40,8 @@
     private static final String TAG = NNInferenceStressTest.class.getSimpleName();
 
     private static final float WARMUP_SECONDS = 0; // No warmup.
-    private static final float RUNTIME_SECONDS = 60 * 60; // 1 hour.
+    private static final float RUNTIME_SECONDS = Duration.ofHours(1).getSeconds();
+    private static final long LONG_STRESS_TEST_DURATION_SECONDS = Duration.ofHours(4).getSeconds();
 
     public NNInferenceStressTest(TestModels.TestModelEntry model) {
         super(model);
@@ -58,7 +62,11 @@
         waitUntilCharged();
         setUseNNApi(true);
         setCompleteInputSet(false);
-        TestAction ta = new TestAction(mModel, WARMUP_SECONDS, RUNTIME_SECONDS);
+        // Sample results for runs of at least LONG_STRESS_TEST_DURATION_SECONDS so the collected
+        // results do not exhaust memory (currently inactive: RUNTIME_SECONDS is only 1 hour).
+        final boolean shouldSampleResults = RUNTIME_SECONDS >= LONG_STRESS_TEST_DURATION_SECONDS;
+        TestAction ta = new TestAction(mModel, WARMUP_SECONDS, RUNTIME_SECONDS,
+                shouldSampleResults);
         runTest(ta, mModel.getTestName());
     }
 }
diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java
index 455ab5e..39c455d 100644
--- a/src/com/android/nn/benchmark/core/NNTestBase.java
+++ b/src/com/android/nn/benchmark/core/NNTestBase.java
@@ -105,6 +105,9 @@
     public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;
 
 
+    /** Keep 1 in 10 benchmark results; must match INFERENCE_OUT_SAMPLE_RATE in run_tflite.h. */
+    public static final int FLAG_SAMPLE_BENCHMARK_RESULTS = 1 << 2;
+
     protected Context mContext;
     protected TextView mText;
     private final String mModelName;
@@ -123,6 +126,7 @@
     private boolean mMmapModel = false;
     // Path where the current model has been stored for execution
     private String mTemporaryModelFilePath;
+    private boolean mSampleResults;
 
     public NNTestBase(String modelName, String modelFile, int[] inputShape,
             InferenceInOutSequence.FromAssets[] inputOutputAssets,
@@ -145,6 +149,7 @@
         mModelHandle = 0;
         mEvaluatorConfig = evaluator;
         mMinSdkVersion = minSdkVersion;
+        mSampleResults = false;
     }
 
     public void useNNApi() {
@@ -267,6 +272,10 @@
         if (mEvaluator == null) {
             flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
         }
+        // For very long tests we will collect only a sample of the results
+        if (mSampleResults) {
+            flags = flags | FLAG_SAMPLE_BENCHMARK_RESULTS;
+        }
         return flags;
     }
 
@@ -421,4 +430,8 @@
     public void close() {
         destroy();
     }
+
+    public void setSampleResult(boolean sampleResults) {
+        this.mSampleResults = sampleResults;
+    }
 }
diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java
index 200bf6b..1cf605c 100644
--- a/src/com/android/nn/benchmark/core/Processor.java
+++ b/src/com/android/nn/benchmark/core/Processor.java
@@ -128,14 +128,23 @@
         mCompilationBenchmarkMaxIterations = maxIterations;
     }
 
-    // Method to retrieve benchmark results for instrumentation tests.
-    // Returns null if the processor is configured to run compilation only
     public BenchmarkResult getInstrumentationResult(
             TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
             throws IOException, BenchmarkException {
+        return getInstrumentationResult(t, warmupTimeSeconds, runTimeSeconds, false);
+    }
+
+    // Method to retrieve benchmark results for instrumentation tests.
+    // Returns null if the processor is configured to run compilation only
+    public BenchmarkResult getInstrumentationResult(
+            TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds,
+            boolean sampleResults)
+            throws IOException, BenchmarkException {
         mTest = changeTest(mTest, t);
+        mTest.setSampleResult(sampleResults);
         try {
-            BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(warmupTimeSeconds,
+            BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(
+                    warmupTimeSeconds,
                     runTimeSeconds);
             return result;
         } finally {