| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package com.android.nn.benchmark.core; |
| |
| import android.annotation.SuppressLint; |
| import android.content.Context; |
| import android.content.res.AssetManager; |
| import android.os.Build; |
| import android.util.Log; |
| import android.util.Pair; |
| import android.widget.TextView; |
| |
| import java.io.File; |
| import java.io.FileOutputStream; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.List; |
| import java.util.Optional; |
| import java.util.Random; |
| import java.util.stream.Collectors; |
| |
| public class NNTestBase implements AutoCloseable { |
| protected static final String TAG = "NN_TESTBASE"; |
| |
| // Used to load the 'native-lib' library on application startup. |
| static { |
| System.loadLibrary("nnbenchmark_jni"); |
| } |
| |
| // Does the device has any NNAPI accelerator? |
| // We only consider a real device, not 'nnapi-reference'. |
| public static native boolean hasAccelerator(); |
| |
| /** |
| * Fills resultList with the name of the available NNAPI accelerators |
| * |
| * @return False if any error occurred, true otherwise |
| */ |
| private static native boolean getAcceleratorNames(List<String> resultList); |
| |
| private synchronized native long initModel( |
| String modelFileName, |
| boolean useNNApi, |
| boolean enableIntermediateTensorsDump, |
| String nnApiDeviceName, |
| boolean mmapModel, |
| String nnApiCacheDir) throws NnApiDelegationFailure; |
| |
| private synchronized native void destroyModel(long modelHandle); |
| |
| private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape); |
| |
| private synchronized native boolean runBenchmark(long modelHandle, |
| List<InferenceInOutSequence> inOutList, |
| List<InferenceResult> resultList, |
| int inferencesSeqMaxCount, |
| float timeoutSec, |
| int flags); |
| |
| private synchronized native CompilationBenchmarkResult runCompilationBenchmark( |
| long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec); |
| |
| private synchronized native void dumpAllLayers( |
| long modelHandle, |
| String dumpPath, |
| List<InferenceInOutSequence> inOutList); |
| |
| public static List<String> availableAcceleratorNames() { |
| List<String> availableAccelerators = new ArrayList<>(); |
| if (NNTestBase.getAcceleratorNames(availableAccelerators)) { |
| return availableAccelerators.stream().filter( |
| acceleratorName -> !acceleratorName.equalsIgnoreCase( |
| "nnapi-reference")).collect(Collectors.toList()); |
| } else { |
| Log.e(TAG, "Unable to retrieve accelerator names!!"); |
| return Collections.EMPTY_LIST; |
| } |
| } |
| |
| /** Discard inference output in inference results. */ |
| public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0; |
| /** |
| * Do not expect golden outputs with inference inputs. |
| * |
| * Useful in cases where there's no straightforward golden output values |
| * for the benchmark. This will also skip calculating basic (golden |
| * output based) error metrics. |
| */ |
| public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1; |
| |
| |
| protected Context mContext; |
| protected TextView mText; |
| private final String mModelName; |
| private final String mModelFile; |
| private long mModelHandle; |
| private final int[] mInputShape; |
| private final InferenceInOutSequence.FromAssets[] mInputOutputAssets; |
| private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets; |
| private final EvaluatorConfig mEvaluatorConfig; |
| private EvaluatorInterface mEvaluator; |
| private boolean mHasGoldenOutputs; |
| private boolean mUseNNApi = false; |
| private boolean mEnableIntermediateTensorsDump = false; |
| private final int mMinSdkVersion; |
| private Optional<String> mNNApiDeviceName = Optional.empty(); |
| private boolean mMmapModel = false; |
| // Path where the current model has been stored for execution |
| private String mTemporaryModelFilePath; |
| |
| public NNTestBase(String modelName, String modelFile, int[] inputShape, |
| InferenceInOutSequence.FromAssets[] inputOutputAssets, |
| InferenceInOutSequence.FromDataset[] inputOutputDatasets, |
| EvaluatorConfig evaluator, int minSdkVersion) { |
| if (inputOutputAssets == null && inputOutputDatasets == null) { |
| throw new IllegalArgumentException( |
| "Neither inputOutputAssets or inputOutputDatasets given - no inputs"); |
| } |
| if (inputOutputAssets != null && inputOutputDatasets != null) { |
| throw new IllegalArgumentException( |
| "Both inputOutputAssets or inputOutputDatasets given. Only one" + |
| "supported at once."); |
| } |
| mModelName = modelName; |
| mModelFile = modelFile; |
| mInputShape = inputShape; |
| mInputOutputAssets = inputOutputAssets; |
| mInputOutputDatasets = inputOutputDatasets; |
| mModelHandle = 0; |
| mEvaluatorConfig = evaluator; |
| mMinSdkVersion = minSdkVersion; |
| } |
| |
| public void useNNApi() { |
| useNNApi(true); |
| } |
| |
| public void useNNApi(boolean value) { |
| mUseNNApi = value; |
| } |
| |
| public void enableIntermediateTensorsDump() { |
| enableIntermediateTensorsDump(true); |
| } |
| |
| public void enableIntermediateTensorsDump(boolean value) { |
| mEnableIntermediateTensorsDump = value; |
| } |
| |
| public void setNNApiDeviceName(String value) { |
| if (!mUseNNApi) { |
| Log.e(TAG, "Setting device name has no effect when not using NNAPI"); |
| } |
| mNNApiDeviceName = Optional.ofNullable(value); |
| } |
| |
| public void setMmapModel(boolean value) { |
| mMmapModel = value; |
| } |
| |
| public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure { |
| mContext = ipcxt; |
| if (mTemporaryModelFilePath != null) { |
| deleteOrWarn(mTemporaryModelFilePath); |
| } |
| mTemporaryModelFilePath = copyAssetToFile(); |
| String nnApiCacheDir = mContext.getCodeCacheDir().toString(); |
| mModelHandle = initModel( |
| mTemporaryModelFilePath, mUseNNApi, mEnableIntermediateTensorsDump, |
| mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir); |
| if (mModelHandle == 0) { |
| Log.e(TAG, "Failed to init the model"); |
| return false; |
| } |
| resizeInputTensors(mModelHandle, mInputShape); |
| |
| if (mEvaluatorConfig != null) { |
| mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets()); |
| } |
| return true; |
| } |
| |
| public String getTestInfo() { |
| return mModelName; |
| } |
| |
| public EvaluatorInterface getEvaluator() { |
| return mEvaluator; |
| } |
| |
| public void checkSdkVersion() throws UnsupportedSdkException { |
| if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) { |
| throw new UnsupportedSdkException("SDK version not supported. Mininum required: " + |
| mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT); |
| } |
| } |
| |
| private void deleteOrWarn(String path) { |
| if (!new File(path).delete()) { |
| Log.w(TAG, String.format( |
| "Unable to delete file '%s'. This might cause device to run out of space.", |
| path)); |
| } |
| } |
| |
| |
| private List<InferenceInOutSequence> getInputOutputAssets() throws IOException { |
| // TODO: Caching, don't read inputs for every inference |
| List<InferenceInOutSequence> inOutList = |
| getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets); |
| |
| Boolean lastGolden = null; |
| for (InferenceInOutSequence sequence : inOutList) { |
| mHasGoldenOutputs = sequence.hasGoldenOutput(); |
| if (lastGolden == null) { |
| lastGolden = mHasGoldenOutputs; |
| } else { |
| if (lastGolden != mHasGoldenOutputs) { |
| throw new IllegalArgumentException( |
| "Some inputs for " + mModelName + " have outputs while some don't."); |
| } |
| } |
| } |
| return inOutList; |
| } |
| |
| public static List<InferenceInOutSequence> getInputOutputAssets(Context context, |
| InferenceInOutSequence.FromAssets[] inputOutputAssets, |
| InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException { |
| // TODO: Caching, don't read inputs for every inference |
| List<InferenceInOutSequence> inOutList = new ArrayList<>(); |
| if (inputOutputAssets != null) { |
| for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) { |
| inOutList.add(ioAsset.readAssets(context.getAssets())); |
| } |
| } |
| if (inputOutputDatasets != null) { |
| for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) { |
| inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir())); |
| } |
| } |
| |
| return inOutList; |
| } |
| |
| public int getDefaultFlags() { |
| int flags = 0; |
| if (!mHasGoldenOutputs) { |
| flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT; |
| } |
| if (mEvaluator == null) { |
| flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT; |
| } |
| return flags; |
| } |
| |
| public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize) |
| throws IOException { |
| if (!dumpDir.exists() || !dumpDir.isDirectory()) { |
| throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory"); |
| } |
| if (!mEnableIntermediateTensorsDump) { |
| throw new IllegalStateException("mEnableIntermediateTensorsDump is " + |
| "set to false, impossible to proceed"); |
| } |
| |
| List<InferenceInOutSequence> ios = getInputOutputAssets(); |
| dumpAllLayers(mModelHandle, dumpDir.toString(), |
| ios.subList(inputAssetIndex, inputAssetSize)); |
| } |
| |
| public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce() |
| throws IOException, BenchmarkException { |
| List<InferenceInOutSequence> ios = getInputOutputAssets(); |
| int flags = getDefaultFlags(); |
| Pair<List<InferenceInOutSequence>, List<InferenceResult>> output = |
| runBenchmark(ios, 1, Float.MAX_VALUE, flags); |
| return output; |
| } |
| |
| public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec) |
| throws IOException, BenchmarkException { |
| // Run as many as possible before timeout. |
| int flags = getDefaultFlags(); |
| return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags); |
| } |
| |
| /** Run through whole input set (once or mutliple times). */ |
| public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet( |
| int setRepeat, |
| float timeoutSec) |
| throws IOException, BenchmarkException { |
| int flags = getDefaultFlags(); |
| List<InferenceInOutSequence> ios = getInputOutputAssets(); |
| int totalSequenceInferencesCount = ios.size() * setRepeat; |
| int extpectedResults = 0; |
| for (InferenceInOutSequence iosSeq : ios) { |
| extpectedResults += iosSeq.size(); |
| } |
| extpectedResults *= setRepeat; |
| |
| Pair<List<InferenceInOutSequence>, List<InferenceResult>> result = |
| runBenchmark(ios, totalSequenceInferencesCount, timeoutSec, |
| flags); |
| if (result.second.size() != extpectedResults) { |
| // We reached a timeout or failed to evaluate whole set for other reason, abort. |
| @SuppressLint("DefaultLocale") |
| final String errorMsg = String.format( |
| "Failed to evaluate complete input set, in %d seconds expected: %d, received:" |
| + " %d", |
| timeoutSec, extpectedResults, result.second.size()); |
| Log.w(TAG, errorMsg); |
| throw new IllegalStateException(errorMsg); |
| } |
| return result; |
| } |
| |
| public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark( |
| List<InferenceInOutSequence> inOutList, |
| int inferencesSeqMaxCount, |
| float timeoutSec, |
| int flags) |
| throws IOException, BenchmarkException { |
| if (mModelHandle == 0) { |
| throw new UnsupportedModelException("Unsupported model"); |
| } |
| List<InferenceResult> resultList = new ArrayList<>(); |
| if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount, |
| timeoutSec, flags)) { |
| throw new BenchmarkException("Failed to run benchmark"); |
| } |
| return new Pair<List<InferenceInOutSequence>, List<InferenceResult>>( |
| inOutList, resultList); |
| } |
| |
| public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec, |
| float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException { |
| if (mModelHandle == 0) { |
| throw new UnsupportedModelException("Unsupported model"); |
| } |
| CompilationBenchmarkResult result = runCompilationBenchmark( |
| mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec); |
| if (result == null) { |
| throw new BenchmarkException("Failed to run compilation benchmark"); |
| } |
| return result; |
| } |
| |
| public void destroy() { |
| if (mModelHandle != 0) { |
| destroyModel(mModelHandle); |
| mModelHandle = 0; |
| } |
| if (mTemporaryModelFilePath != null) { |
| deleteOrWarn(mTemporaryModelFilePath); |
| mTemporaryModelFilePath = null; |
| } |
| } |
| |
| private final Random mRandom = new Random(System.currentTimeMillis()); |
| |
| // We need to copy it to cache dir, so that TFlite can load it directly. |
| private String copyAssetToFile() throws IOException { |
| @SuppressLint("DefaultLocale") |
| String outFileName = |
| String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(), |
| mModelFile, |
| Thread.currentThread().getId(), mRandom.nextInt(10000)); |
| |
| copyAssetToFile(mContext, mModelFile + ".tflite", outFileName); |
| return outFileName; |
| } |
| |
| public static boolean copyModelToFile(Context context, String modelFileName, File targetFile) |
| throws IOException { |
| if (!targetFile.exists() && !targetFile.createNewFile()) { |
| Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath())); |
| return false; |
| } |
| NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath()); |
| return true; |
| } |
| |
| public static void copyAssetToFile(Context context, String modelAssetName, String targetPath) |
| throws IOException { |
| AssetManager assetManager = context.getAssets(); |
| try { |
| File outFile = new File(targetPath); |
| |
| try (InputStream in = assetManager.open(modelAssetName); |
| FileOutputStream out = new FileOutputStream(outFile)) { |
| byte[] byteBuffer = new byte[1024]; |
| int readBytes = -1; |
| while ((readBytes = in.read(byteBuffer)) != -1) { |
| out.write(byteBuffer, 0, readBytes); |
| } |
| } |
| } catch (IOException e) { |
| Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e); |
| throw e; |
| } |
| } |
| |
| @Override |
| public void close() { |
| destroy(); |
| } |
| } |