blob: 3ef56a020f41d3c994777e98fff131f5f934613b [file] [log] [blame]
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.nn.benchmark.app;
import android.annotation.SuppressLint;
import android.content.Context;
import android.content.Intent;
import android.os.RemoteException;
import android.test.ActivityInstrumentationTestCase2;
import android.test.UiThreadTest;
import android.test.suitebuilder.annotation.LargeTest;
import android.util.Log;
import androidx.test.InstrumentationRegistry;
import com.android.nn.benchmark.core.BenchmarkException;
import com.android.nn.benchmark.core.NNTestBase;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.IntStream;
@RunWith(Parameterized.class)
public class NNClientEarlyTerminationTest extends
ActivityInstrumentationTestCase2<NNParallelTestActivity> {
private static final String TAG = "NNClientEarlyTermination";
private static final Duration MAX_SEPARATE_PROCESS_EXECUTION_TIME = Duration.ofSeconds(70);
public static final int NNAPI_CLIENTS_COUNT = 4;
private final ExecutorService mDriverLivenessValidationExecutor =
Executors.newSingleThreadExecutor();
private final String mAcceleratorName;
@Rule
public TestName mTestName = new TestName();
public NNClientEarlyTerminationTest(String acceleratorName) {
super(NNParallelTestActivity.class);
mAcceleratorName = acceleratorName;
}
@Parameters(name = "Accelerator({0})")
public static Iterable<String> targetAccelerators() {
return NNTestBase.availableAcceleratorNames();
}
private static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator(
Context context, String acceleratorName) {
return TestModels.modelsList().stream()
.map(model -> new TestModels.TestModelEntry(
model.mModelName,
model.mBaselineSec,
model.mInputShape,
model.mInOutAssets,
model.mInOutDatasets,
model.mTestName,
model.mModelFile,
null, // Disable evaluation.
model.mMinSdkVersion)
).filter(
model -> Processor.isTestModelSupportedByAccelerator(
context,
model, acceleratorName)).findAny();
}
@Before
@Override
public void setUp() {
injectInstrumentation(InstrumentationRegistry.getInstrumentation());
final Intent runSomeInferencesInASeparateProcess = compileSupportedModelsOnNThreadsFor(
NNAPI_CLIENTS_COUNT,
MAX_SEPARATE_PROCESS_EXECUTION_TIME);
setActivityIntent(runSomeInferencesInASeparateProcess);
}
private long ramdomInRange(long min, long max) {
return min + (long) (Math.random() * (max - min));
}
@Test
@LargeTest
@UiThreadTest
public void testDriverDoesNotFailWithParallelThreads()
throws ExecutionException, InterruptedException, RemoteException {
final NNParallelTestActivity activity = getActivity();
Optional<TestModels.TestModelEntry> modelForLivenessTest =
findTestModelRunningOnAccelerator(activity, mAcceleratorName);
assertTrue("No model available to be run on accelerator " + mAcceleratorName,
modelForLivenessTest.isPresent());
final DriverLivenessChecker driverLivenessChecker = new DriverLivenessChecker(activity,
mAcceleratorName, modelForLivenessTest.get());
Future<Boolean> driverDidNotCrash = mDriverLivenessValidationExecutor.submit(
driverLivenessChecker);
// Causing failure before tests would end on their own.
final long maxTerminationTime = MAX_SEPARATE_PROCESS_EXECUTION_TIME.toMillis() / 2;
final long minTerminationTime = MAX_SEPARATE_PROCESS_EXECUTION_TIME.toMillis() / 4;
Thread.sleep(ramdomInRange(minTerminationTime, maxTerminationTime));
try {
activity.killTestProcess();
} catch (RemoteException e) {
driverLivenessChecker.stop();
throw e;
}
NNParallelTestActivity.TestResult testResult = activity.testResult();
driverLivenessChecker.stop();
assertEquals("Remote process is expected to be killed",
NNParallelTestActivity.TestResult.CRASH,
testResult);
assertTrue("Driver shouldn't crash if a client process is terminated",
driverDidNotCrash.get());
}
private Intent compileSupportedModelsOnNThreadsFor(int threadCount, Duration testDuration) {
Intent intent = new Intent();
intent.putExtra(
NNParallelTestActivity.EXTRA_TEST_LIST, IntStream.range(0,
TestModels.modelsList().size()).toArray());
intent.putExtra(NNParallelTestActivity.EXTRA_THREAD_COUNT, threadCount);
intent.putExtra(NNParallelTestActivity.EXTRA_TEST_DURATION_MILLIS, testDuration.toMillis());
intent.putExtra(NNParallelTestActivity.EXTRA_RUN_IN_SEPARATE_PROCESS, true);
intent.putExtra(NNParallelTestActivity.EXTRA_TEST_NAME, mTestName.getMethodName());
intent.putExtra(NNParallelTestActivity.EXTRA_ACCELERATOR_NAME, mAcceleratorName);
intent.putExtra(NNParallelTestActivity.EXTRA_IGNORE_UNSUPPORTED_MODELS, true);
intent.putExtra(NNParallelTestActivity.EXTRA_RUN_MODEL_COMPILATION_ONLY, true);
return intent;
}
static class DriverLivenessChecker implements Callable<Boolean> {
final Processor mProcessor;
private final AtomicBoolean mRun = new AtomicBoolean(true);
private final TestModels.TestModelEntry mTestModelEntry;
public DriverLivenessChecker(Context context, String acceleratorName,
TestModels.TestModelEntry testModelEntry) {
mProcessor = new Processor(context,
new Processor.Callback() {
@SuppressLint("DefaultLocale")
@Override
public void onBenchmarkFinish(boolean ok) {
}
@Override
public void onStatusUpdate(int testNumber, int numTests, String modelName) {
}
}, new int[0]);
mProcessor.setUseNNApi(true);
mProcessor.setCompleteInputSet(false);
mProcessor.setNnApiAcceleratorName(acceleratorName);
mTestModelEntry = testModelEntry;
}
public void stop() {
mRun.set(false);
}
@Override
public Boolean call() throws Exception {
while (mRun.get()) {
try {
mProcessor.getInstrumentationResult(mTestModelEntry, 0, 3);
} catch (IOException | BenchmarkException e) {
Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName));
}
}
return true;
}
}
}