blob: 9f01128d82c6ec35a784b79d35d085d6e7eece89 [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.nn.crashtest.core.test;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import android.annotation.SuppressLint;
import android.content.Context;
import android.content.Intent;
import android.util.Log;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.crashtest.core.CrashTest;
import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer;
import com.android.nn.benchmark.core.TfLiteBackend;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class RunModelsInParallel implements CrashTest {
private static final String MODELS = "models";
private static final String DURATION = "duration";
private static final String THREADS = "thread_counts";
private static final String TEST_NAME = "test_name";
private static final String ACCELERATOR_NAME = "accelerator_name";
private static final String IGNORE_UNSUPPORTED_MODELS = "ignore_unsupported_models";
private static final String RUN_MODEL_COMPILATION_ONLY = "run_model_compilation_only";
private static final String MEMORY_MAP_MODEL = "memory_map_model";
private final Set<Processor> activeTests = new HashSet<>();
private final List<Boolean> mTestCompletionResults = Collections.synchronizedList(
new ArrayList<>());
private long mTestDurationMillis = 0;
private int mThreadCount = 0;
private int[] mTestList = new int[0];
private String mTestName;
private String mAcceleratorName;
private boolean mIgnoreUnsupportedModels;
private Context mContext;
private boolean mRunModelCompilationOnly;
private ExecutorService mExecutorService = null;
private CountDownLatch mParallelTestComplete;
private ProgressListener mProgressListener;
private boolean mMmapModel;
static public CrashTestIntentInitializer intentInitializer(int[] models, int threadCount,
Duration duration, String testName, String acceleratorName,
boolean ignoreUnsupportedModels,
boolean runModelCompilationOnly, boolean mmapModel) {
return intent -> {
intent.putExtra(MODELS, models);
intent.putExtra(DURATION, duration.toMillis());
intent.putExtra(THREADS, threadCount);
intent.putExtra(TEST_NAME, testName);
intent.putExtra(ACCELERATOR_NAME, acceleratorName);
intent.putExtra(IGNORE_UNSUPPORTED_MODELS, ignoreUnsupportedModels);
intent.putExtra(RUN_MODEL_COMPILATION_ONLY, runModelCompilationOnly);
intent.putExtra(MEMORY_MAP_MODEL, mmapModel);
};
}
@Override
public void init(Context context, Intent configParams,
Optional<ProgressListener> progressListener) {
mTestList = configParams.getIntArrayExtra(MODELS);
mThreadCount = configParams.getIntExtra(THREADS, 10);
mTestDurationMillis = configParams.getLongExtra(DURATION, 1000 * 60 * 10);
mTestName = configParams.getStringExtra(TEST_NAME);
mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME);
mIgnoreUnsupportedModels = mAcceleratorName != null && configParams.getBooleanExtra(
IGNORE_UNSUPPORTED_MODELS, false);
mRunModelCompilationOnly = configParams.getBooleanExtra(RUN_MODEL_COMPILATION_ONLY, false);
mMmapModel = configParams.getBooleanExtra(MEMORY_MAP_MODEL, false);
mContext = context;
mProgressListener = progressListener.orElseGet(() -> (Optional<String> message) -> {
Log.v(CrashTest.TAG, message.orElse("."));
});
mExecutorService = Executors.newFixedThreadPool(mThreadCount);
mTestCompletionResults.clear();
}
@Override
public Optional<String> call() {
mParallelTestComplete = new CountDownLatch(mThreadCount);
for (int i = 0; i < mThreadCount; i++) {
Processor testProcessor = createSubTestRunner(mTestList, i);
activeTests.add(testProcessor);
mExecutorService.submit(testProcessor);
}
return completedSuccessfully();
}
private Processor createSubTestRunner(final int[] testList, final int testIndex) {
final Processor result = new Processor(mContext, new Processor.Callback() {
@SuppressLint("DefaultLocale")
@Override
public void onBenchmarkFinish(boolean ok) {
notifyProgress("Test '%s': Benchmark #%d completed %s", mTestName, testIndex,
ok ? "successfully" : "with failure");
mTestCompletionResults.add(ok);
mParallelTestComplete.countDown();
}
@Override
public void onStatusUpdate(int testNumber, int numTests, String modelName) {
}
}, testList);
result.setTfLiteBackend(TfLiteBackend.NNAPI);
result.setCompleteInputSet(false);
result.setNnApiAcceleratorName(mAcceleratorName);
result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels);
result.setRunModelCompilationOnly(mRunModelCompilationOnly);
result.setMmapModel(mMmapModel);
return result;
}
private void endTests() {
ExecutorService terminatorsThreadPool = Executors.newFixedThreadPool(activeTests.size());
List<Future<?>> terminationCommands = new ArrayList<>();
for (final Processor test : activeTests) {
// Exit will block until the thread is completed
terminationCommands.add(terminatorsThreadPool.submit(
() -> test.exitWithTimeout(Duration.ofSeconds(20).toMillis())));
}
terminationCommands.forEach(terminationCommand -> {
try {
terminationCommand.get();
} catch (ExecutionException e) {
Log.w(TAG, "Failure while waiting for completion of tests", e);
} catch (InterruptedException e) {
Thread.interrupted();
}
});
}
@SuppressLint("DefaultLocale")
void notifyProgress(String messageFormat, Object... args) {
mProgressListener.testProgress(Optional.of(String.format(messageFormat, args)));
}
// This method blocks until the tests complete and returns true if all tests completed
// successfully
@SuppressLint("DefaultLocale")
private Optional<String> completedSuccessfully() {
try {
boolean testsEnded = mParallelTestComplete.await(mTestDurationMillis, MILLISECONDS);
if (!testsEnded) {
Log.i(TAG,
String.format(
"Test '%s': Tests are not completed (they might have been "
+ "designed to run "
+ "indefinitely. Forcing termination.", mTestName));
endTests();
}
} catch (InterruptedException ignored) {
Thread.currentThread().interrupt();
}
final long failedTestCount = mTestCompletionResults.stream().filter(
testResult -> !testResult).count();
if (failedTestCount > 0) {
String failureMsg = String.format("Test '%s': %d out of %d test failed", mTestName,
failedTestCount,
mTestCompletionResults.size());
Log.w(CrashTest.TAG, failureMsg);
return failure(failureMsg);
} else {
Log.i(CrashTest.TAG,
String.format("Test '%s': Test completed successfully", mTestName));
return success();
}
}
}