blob: 83875d6940f24e9749cb409b4130f8a7ac9fe72b [file] [log] [blame]
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.nn.benchmark.app;
import android.content.Context;
import android.util.Log;
import androidx.test.InstrumentationRegistry;
import com.android.nn.benchmark.core.BenchmarkException;
import com.android.nn.benchmark.core.BenchmarkResult;
import com.android.nn.benchmark.core.NNTestBase;
import com.android.nn.benchmark.core.NnApiDelegationFailure;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
import com.android.nn.benchmark.core.TfLiteBackend;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
/**
 * Helpers shared by NNAPI tests that need to run against one or more specific accelerators.
 *
 * <p>Accelerator selection is driven by two instrumentation arguments:
 * <ul>
 *   <li>{@link #ACCELERATOR_FILTER_PROPERTY}: a regular expression matched against the names
 *       returned by {@link NNTestBase#availableAcceleratorNames()}.</li>
 *   <li>{@link #INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY}: when {@code "true"}, a {@code null}
 *       accelerator name is appended so tests also run with no specified target accelerator.</li>
 * </ul>
 */
public interface AcceleratorSpecificTestSupport {
    String TAG = "AcceleratorTest";

    /** Instrumentation argument holding a regex used to select target accelerators by name. */
    String ACCELERATOR_FILTER_PROPERTY = "nnCrashtestDeviceFilter";

    /**
     * Instrumentation argument that, when {@code "true"}, appends a {@code null} accelerator name
     * to the target list so tests are also run with no specified target accelerator.
     */
    String INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY = "nnCrashtestIncludeNnapiReference";

    /**
     * Returns the first entry of {@link TestModels#modelsList()} that
     * {@link Processor#isTestModelSupportedByAccelerator} accepts for {@code acceleratorName}.
     *
     * <p>Models are checked in list order and the search stops at the first match, so models
     * after it are never probed.
     *
     * @param context context passed through to the support check
     * @param acceleratorName name of the accelerator to probe
     * @return the first supported model, or {@link Optional#empty()} if none is supported
     * @throws NnApiDelegationFailure propagated from the support check
     */
    static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator(
            Context context, String acceleratorName) throws NnApiDelegationFailure {
        for (TestModels.TestModelEntry model : TestModels.modelsList()) {
            if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
                return Optional.of(model);
            }
        }
        return Optional.empty();
    }

    /**
     * Returns every entry of {@link TestModels#modelsList()} that
     * {@link Processor#isTestModelSupportedByAccelerator} accepts for {@code acceleratorName},
     * in list order.
     *
     * @param context context passed through to the support check
     * @param acceleratorName name of the accelerator to probe
     * @return a mutable list of supported models; empty if none is supported
     * @throws NnApiDelegationFailure propagated from the support check
     */
    static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator(
            Context context, String acceleratorName) throws NnApiDelegationFailure {
        List<TestModels.TestModelEntry> result = new ArrayList<>();
        for (TestModels.TestModelEntry model : TestModels.modelsList()) {
            if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
                result.add(model);
            }
        }
        return result;
    }

    /**
     * Returns a pseudo-random value from {@link Math#random()} scaled into
     * {@code [min, max)}; returns {@code min} when {@code min == max}.
     *
     * <p>The method name is misspelled ("ramdom") but is kept unchanged for source compatibility
     * with existing callers.
     *
     * <p>NOTE(review): {@code max - min} is computed in {@code long} and then scaled as a
     * {@code double}, so extremely large ranges can overflow or lose precision. Callers are
     * assumed to pass small ranges.
     */
    default long ramdomInRange(long min, long max) {
        return min + (long) (Math.random() * (max - min));
    }

    /**
     * Returns the instrumentation argument {@code key} as a string, or {@code defaultValue} if the
     * argument is not set.
     */
    static String getTestParameter(String key, String defaultValue) {
        return InstrumentationRegistry.getArguments().getString(key, defaultValue);
    }

    /**
     * Returns the instrumentation argument {@code key} parsed as a boolean, or
     * {@code defaultValue} if the argument is not set.
     *
     * <p>Instrumentation arguments are always passed as strings, so the value is parsed with
     * {@link Boolean#parseBoolean}: only a case-insensitive {@code "true"} yields {@code true}.
     */
    static boolean getBooleanTestParameter(String key, boolean defaultValue) {
        return Boolean.parseBoolean(
                InstrumentationRegistry.getArguments().getString(
                        key, Boolean.toString(defaultValue)));
    }

    /**
     * Returns the available accelerator names matching the {@link #ACCELERATOR_FILTER_PROPERTY}
     * regex, which defaults to {@code ".+"} (any non-empty name) when the argument is unset.
     *
     * <p>If {@link #INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY} is {@code "true"}, a trailing
     * {@code null} entry is appended to also run tests with no specified target accelerator.
     *
     * @return a mutable list of accelerator names, possibly containing one {@code null}
     */
    static List<String> getTargetAcceleratorNames() {
        final String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ".+");
        final List<String> accelerators = NNTestBase.availableAcceleratorNames().stream()
                .filter(name -> name.matches(acceleratorFilter))
                .collect(Collectors.toCollection(ArrayList::new));
        if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) {
            accelerators.add(null); // also run tests with no specified target accelerator
        }
        return accelerators;
    }

    /**
     * Like {@link #getTargetAcceleratorNames()}, but returns an empty list when no (or an empty)
     * {@link #ACCELERATOR_FILTER_PROPERTY} argument has been specified.
     */
    static List<String> getOptionalTargetAcceleratorNames() {
        if (getTestParameter(ACCELERATOR_FILTER_PROPERTY, "").isEmpty()) {
            return Collections.emptyList();
        }
        // A non-empty filter is set, so getTargetAcceleratorNames() reads that same filter (its
        // ".+" default is not used) and applies identical selection logic.
        return getTargetAcceleratorNames();
    }

    /**
     * Expands each test configuration into one configuration per target accelerator, as returned
     * by {@link #getTargetAcceleratorNames()}. The accelerator name (possibly {@code null}) is
     * appended as the last parameter.
     *
     * @param testConfig parameter rows; each row is copied, never modified
     * @return the cross product of {@code testConfig} and the target accelerators
     */
    static List<Object[]> perAcceleratorTestConfig(List<Object[]> testConfig) {
        // Resolve the accelerator list once so every configuration sees the same set and the
        // device list / instrumentation arguments are not re-queried per configuration.
        final List<String> accelerators = getTargetAcceleratorNames();
        return testConfig.stream()
                .flatMap(params -> accelerators.stream().map(accelerator -> {
                    // Copy into an explicit Object[]: a row with a narrower runtime component type
                    // (e.g. Integer[]) would otherwise throw ArrayStoreException on the store.
                    Object[] result = Arrays.copyOf(params, params.length + 1, Object[].class);
                    result[params.length] = accelerator;
                    return result;
                }))
                .collect(Collectors.toList());
    }

    /**
     * Generates a per-accelerator list of test configurations if an accelerator filter has been
     * specified (see {@link #getOptionalTargetAcceleratorNames()}). Otherwise it returns the
     * original configurations, each with an extra {@code null} accelerator-name parameter.
     *
     * @param testConfig parameter rows; each row is copied, never modified
     * @return the expanded configurations, with the accelerator name appended to each row
     */
    static List<Object[]> maybeAddAcceleratorsToTestConfig(List<Object[]> testConfig) {
        // Resolved once for all configurations (see perAcceleratorTestConfig).
        final List<String> selected = getOptionalTargetAcceleratorNames();
        final List<String> accelerators = selected.isEmpty()
                ? Collections.singletonList((String) null)
                : selected;
        return testConfig.stream()
                .flatMap(params -> accelerators.stream().map(accelerator -> {
                    // Explicit Object[] copy avoids ArrayStoreException for narrower row types.
                    Object[] result = Arrays.copyOf(params, params.length + 1, Object[].class);
                    result[params.length] = accelerator;
                    return result;
                }))
                .collect(Collectors.toList());
    }

    /**
     * {@link Callable} that repeatedly runs one model on an accelerator through the NNAPI TfLite
     * backend until {@link #stop()} is called or a run fails.
     *
     * <p>{@link #call()} behaves as follows:
     * <ul>
     *   <li>It returns {@code false} as soon as a run reports a benchmark error or throws
     *       {@link IOException} or {@link BenchmarkException}.</li>
     *   <li>It returns {@code true} once the loop exits after {@link #stop()}.</li>
     *   <li>Other exceptions propagate to the caller.</li>
     * </ul>
     * {@link #stop()} only flips an {@link AtomicBoolean}, so it can be called from another
     * thread. The iteration in progress completes before the loop observes the flag.
     */
    class DriverLivenessChecker implements Callable<Boolean> {
        final Processor mProcessor;
        private final AtomicBoolean mRun = new AtomicBoolean(true);
        private final TestModels.TestModelEntry mTestModelEntry;

        /**
         * Configures a {@link Processor} to run {@code testModelEntry} via the NNAPI backend on
         * {@code acceleratorName}, honoring the NNAPI support-library settings from
         * {@link NNTestBase}.
         */
        public DriverLivenessChecker(Context context, String acceleratorName,
                TestModels.TestModelEntry testModelEntry) {
            // Progress callbacks are ignored; call() inspects the returned BenchmarkResult instead.
            // NOTE(review): the meaning of the empty int[] argument is defined by Processor —
            // confirm it is not expected to list test indices here.
            mProcessor = new Processor(context,
                    new Processor.Callback() {
                        @Override
                        public void onBenchmarkFinish(boolean ok) {
                        }

                        @Override
                        public void onStatusUpdate(int testNumber, int numTests, String modelName) {
                        }
                    }, new int[0]);
            mProcessor.setTfLiteBackend(TfLiteBackend.NNAPI);
            mProcessor.setCompleteInputSet(false);
            mProcessor.setNnApiAcceleratorName(acceleratorName);
            mProcessor.setUseNnApiSupportLibrary(NNTestBase.shouldUseNnApiSupportLibrary());
            mProcessor.setExtractNnApiSupportLibrary(NNTestBase.shouldExtractNnApiSupportLibrary());
            mProcessor.setNnApiSupportLibraryVendor(NNTestBase.getNnApiSupportLibraryVendor());
            mTestModelEntry = testModelEntry;
        }

        /** Asks the {@link #call()} loop to exit after the current iteration. */
        public void stop() {
            mRun.set(false);
        }

        @Override
        public Boolean call() throws Exception {
            while (mRun.get()) {
                try {
                    // NOTE(review): 0 and 3 are presumably warm-up and run durations for one short
                    // pass — confirm against Processor#getInstrumentationResult.
                    BenchmarkResult modelExecutionResult = mProcessor.getInstrumentationResult(
                            mTestModelEntry, 0, 3);
                    if (modelExecutionResult.hasBenchmarkError()) {
                        Log.e(TAG, String.format("Benchmark failed with message %s",
                                modelExecutionResult.getBenchmarkError()));
                        return false;
                    }
                } catch (IOException | BenchmarkException e) {
                    Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName), e);
                    return false;
                }
            }
            return true;
        }
    }
}