blob: 5e4cad64c400311b52041a41120f7191d0b499d5 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.nn.benchmark.core;
import android.os.Bundle;
import android.os.Parcel;
import android.os.Parcelable;
import android.text.TextUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class BenchmarkResult implements Parcelable {
public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI";
public final static String BACKEND_TFLITE_CPU = "TFLite_CPU";
private final static int TIME_FREQ_ARRAY_SIZE = 32;
private float mTotalTimeSec;
private float mSumOfMSEs;
private float mMaxSingleError;
private int mIterations;
private float mTimeStdDeviation;
private String mTestInfo;
private int mNumberOfEvaluatorResults;
private String[] mEvaluatorKeys = {};
private float[] mEvaluatorResults = {};
/** Type of backend used for inference */
private String mBackendType;
/** Time offset for inference frequency counts */
private float mTimeFreqStartSec;
/** Index time offset for inference frequency counts */
private float mTimeFreqStepSec;
/**
* Array of inference frequency counts.
* Each entry contains inference count for time range:
* [mTimeFreqStartSec + i*mTimeFreqStepSec, mTimeFreqStartSec + (1+i*mTimeFreqStepSec)
*/
private float[] mTimeFreqSec = {};
/** Size of test set using for inference */
private int mTestSetSize;
/** List of validation errors */
private String[] mValidationErrors = {};
/** Error that prevents the benchmark from running, e.g. SDK version not supported. */
private String mBenchmarkError;
public BenchmarkResult(float totalTimeSec, int iterations, float timeVarianceSec,
float sumOfMSEs, float maxSingleError, String testInfo,
String[] evaluatorKeys, float[] evaluatorResults,
float timeFreqStartSec, float timeFreqStepSec, float[] timeFreqSec,
String backendType, int testSetSize, String[] validationErrors) {
mTotalTimeSec = totalTimeSec;
mSumOfMSEs = sumOfMSEs;
mMaxSingleError = maxSingleError;
mIterations = iterations;
mTimeStdDeviation = timeVarianceSec;
mTestInfo = testInfo;
mTimeFreqStartSec = timeFreqStartSec;
mTimeFreqStepSec = timeFreqStepSec;
mTimeFreqSec = timeFreqSec;
mBackendType = backendType;
mTestSetSize = testSetSize;
if (validationErrors == null) {
mValidationErrors = new String[0];
} else {
mValidationErrors = validationErrors;
}
if (evaluatorKeys == null) {
mEvaluatorKeys = new String[0];
} else {
mEvaluatorKeys = evaluatorKeys;
}
if (evaluatorResults == null) {
mEvaluatorResults = new float[0];
} else {
mEvaluatorResults = evaluatorResults;
}
if (mEvaluatorResults.length != mEvaluatorKeys.length) {
throw new IllegalArgumentException("Different number of evaluator keys vs values");
}
mNumberOfEvaluatorResults = mEvaluatorResults.length;
}
public BenchmarkResult(String benchmarkError) {
mBenchmarkError = benchmarkError;
}
public boolean hasValidationErrors() {
return mValidationErrors.length > 0;
}
protected BenchmarkResult(Parcel in) {
mTotalTimeSec = in.readFloat();
mSumOfMSEs = in.readFloat();
mMaxSingleError = in.readFloat();
mIterations = in.readInt();
mTimeStdDeviation = in.readFloat();
mTestInfo = in.readString();
mNumberOfEvaluatorResults = in.readInt();
mEvaluatorKeys = new String[mNumberOfEvaluatorResults];
in.readStringArray(mEvaluatorKeys);
mEvaluatorResults = new float[mNumberOfEvaluatorResults];
in.readFloatArray(mEvaluatorResults);
if (mEvaluatorResults.length != mEvaluatorKeys.length) {
throw new IllegalArgumentException("Different number of evaluator keys vs values");
}
mTimeFreqStartSec = in.readFloat();
mTimeFreqStepSec = in.readFloat();
int timeFreqSecLength = in.readInt();
mTimeFreqSec = new float[timeFreqSecLength];
in.readFloatArray(mTimeFreqSec);
mBackendType = in.readString();
mTestSetSize = in.readInt();
int validationsErrorsSize = in.readInt();
mValidationErrors = new String[validationsErrorsSize];
in.readStringArray(mValidationErrors);
mBenchmarkError = in.readString();
}
@Override
public int describeContents() {
return 0;
}
@Override
public void writeToParcel(Parcel dest, int flags) {
dest.writeFloat(mTotalTimeSec);
dest.writeFloat(mSumOfMSEs);
dest.writeFloat(mMaxSingleError);
dest.writeInt(mIterations);
dest.writeFloat(mTimeStdDeviation);
dest.writeString(mTestInfo);
dest.writeInt(mNumberOfEvaluatorResults);
dest.writeStringArray(mEvaluatorKeys);
dest.writeFloatArray(mEvaluatorResults);
dest.writeFloat(mTimeFreqStartSec);
dest.writeFloat(mTimeFreqStepSec);
dest.writeInt(mTimeFreqSec.length);
dest.writeFloatArray(mTimeFreqSec);
dest.writeString(mBackendType);
dest.writeInt(mTestSetSize);
dest.writeInt(mValidationErrors.length);
dest.writeStringArray(mValidationErrors);
dest.writeString(mBenchmarkError);
}
@SuppressWarnings("unused")
public static final Parcelable.Creator<BenchmarkResult> CREATOR =
new Parcelable.Creator<BenchmarkResult>() {
@Override
public BenchmarkResult createFromParcel(Parcel in) {
return new BenchmarkResult(in);
}
@Override
public BenchmarkResult[] newArray(int size) {
return new BenchmarkResult[size];
}
};
public float getMeanTimeSec() {
return mTotalTimeSec / mIterations;
}
@Override
public String toString() {
if (!TextUtils.isEmpty(mBenchmarkError)) {
return mBenchmarkError;
}
StringBuilder result = new StringBuilder("BenchmarkResult{" +
"mTestInfo='" + mTestInfo + '\'' +
", getMeanTimeSec()=" + getMeanTimeSec() +
", mTotalTimeSec=" + mTotalTimeSec +
", mSumOfMSEs=" + mSumOfMSEs +
", mMaxSingleErrors=" + mMaxSingleError +
", mIterations=" + mIterations +
", mTimeStdDeviation=" + mTimeStdDeviation +
", mTimeFreqStartSec=" + mTimeFreqStartSec +
", mTimeFreqStepSec=" + mTimeFreqStepSec +
", mBackendType=" + mBackendType +
", mTestSetSize=" + mTestSetSize);
for (int i = 0; i < mEvaluatorKeys.length; i++) {
result.append(", ").append(mEvaluatorKeys[i]).append("=").append(mEvaluatorResults[i]);
}
result.append(", mValidationErrors=[");
for (int i = 0; i < mValidationErrors.length; i++) {
result.append(mValidationErrors[i]);
if (i < mValidationErrors.length - 1) {
result.append(",");
}
}
result.append("]");
result.append('}');
return result.toString();
}
public String getSummary(float baselineSec) {
if (!TextUtils.isEmpty(mBenchmarkError)) {
return mBenchmarkError;
}
java.text.DecimalFormat df = new java.text.DecimalFormat("######.##");
return df.format(rebase(getMeanTimeSec(), baselineSec)) +
"X, n=" + mIterations + ", μ=" + df.format(getMeanTimeSec() * 1000.0)
+ "ms, σ=" + df.format(mTimeStdDeviation * 1000.0) + "ms";
}
public Bundle toBundle(String testName) {
Bundle results = new Bundle();
if (!TextUtils.isEmpty(mBenchmarkError)) {
results.putString(testName + "_error", mBenchmarkError);
return results;
}
// Reported in ms
results.putFloat(testName + "_avg", getMeanTimeSec() * 1000.0f);
results.putFloat(testName + "_std_dev", mTimeStdDeviation * 1000.0f);
results.putFloat(testName + "_total_time", mTotalTimeSec * 1000.0f);
results.putFloat(testName + "_mean_square_error", mSumOfMSEs / mIterations);
results.putFloat(testName + "_max_single_error", mMaxSingleError);
results.putInt(testName + "_iterations", mIterations);
for (int i = 0; i < mEvaluatorKeys.length; i++) {
results.putFloat(testName + "_" + mEvaluatorKeys[i],
mEvaluatorResults[i]);
}
return results;
}
public String toCsvLine() {
if (!TextUtils.isEmpty(mBenchmarkError)) {
return "";
}
StringBuilder sb = new StringBuilder();
sb.append(String.join(",",
mBackendType,
String.valueOf(mIterations),
String.valueOf(mTotalTimeSec),
String.valueOf(mMaxSingleError),
String.valueOf(mTestSetSize),
String.valueOf(mTimeFreqStartSec),
String.valueOf(mTimeFreqStepSec),
String.valueOf(mEvaluatorKeys.length),
String.valueOf(mTimeFreqSec.length),
String.valueOf(mValidationErrors.length)));
sb.append(',');
sb.append(String.join(",", Arrays.asList(mEvaluatorKeys)));
for (int i = 0; i < mEvaluatorKeys.length; ++i) {
sb.append(',').append(mEvaluatorResults[i]);
}
for (float value : mTimeFreqSec) {
sb.append(',').append(value);
}
for (String validationError : mValidationErrors) {
sb.append(',').append(validationError.replace(',', ' '));
}
sb.append('\n');
return sb.toString();
}
float rebase(float v, float baselineSec) {
if (v > 0.001) {
v = baselineSec / v;
}
return v;
}
public static BenchmarkResult fromInferenceResults(
String testInfo,
String backendType,
List<InferenceInOutSequence> inferenceInOuts,
List<InferenceResult> inferenceResults,
EvaluatorInterface evaluator) {
float totalTime = 0;
int iterations = 0;
float sumOfMSEs = 0;
float maxSingleError = 0;
float maxComputeTimeSec = 0.0f;
float minComputeTimeSec = Float.MAX_VALUE;
for (InferenceResult iresult : inferenceResults) {
iterations++;
totalTime += iresult.mComputeTimeSec;
if (iresult.mMeanSquaredErrors != null) {
for (float mse : iresult.mMeanSquaredErrors) {
sumOfMSEs += mse;
}
}
if (iresult.mMaxSingleErrors != null) {
for (float mse : iresult.mMaxSingleErrors) {
if (mse > maxSingleError) {
maxSingleError = mse;
}
}
}
if (maxComputeTimeSec < iresult.mComputeTimeSec) {
maxComputeTimeSec = iresult.mComputeTimeSec;
}
if (minComputeTimeSec > iresult.mComputeTimeSec) {
minComputeTimeSec = iresult.mComputeTimeSec;
}
}
float inferenceMean = (totalTime / iterations);
float variance = 0.0f;
for (InferenceResult iresult : inferenceResults) {
float v = (iresult.mComputeTimeSec - inferenceMean);
variance += v * v;
}
variance /= iterations;
String[] evaluatorKeys = null;
float[] evaluatorResults = null;
String[] validationErrors = null;
if (evaluator != null) {
ArrayList<String> keys = new ArrayList<String>();
ArrayList<Float> results = new ArrayList<Float>();
ArrayList<String> validationErrorsList = new ArrayList<>();
evaluator.EvaluateAccuracy(inferenceInOuts, inferenceResults, keys, results,
validationErrorsList);
evaluatorKeys = new String[keys.size()];
evaluatorKeys = keys.toArray(evaluatorKeys);
evaluatorResults = new float[results.size()];
for (int i = 0; i < evaluatorResults.length; i++) {
evaluatorResults[i] = results.get(i).floatValue();
}
validationErrors = new String[validationErrorsList.size()];
validationErrorsList.toArray(validationErrors);
}
// Calculate inference frequency/histogram across TIME_FREQ_ARRAY_SIZE buckets.
float[] timeFreqSec = new float[TIME_FREQ_ARRAY_SIZE];
float stepSize = (maxComputeTimeSec - minComputeTimeSec) / (TIME_FREQ_ARRAY_SIZE - 1);
for (InferenceResult iresult : inferenceResults) {
timeFreqSec[(int) ((iresult.mComputeTimeSec - minComputeTimeSec) / stepSize)] += 1;
}
// Calc test set size
int testSetSize = 0;
for (InferenceInOutSequence iios : inferenceInOuts) {
testSetSize += iios.size();
}
return new BenchmarkResult(totalTime, iterations, (float) Math.sqrt(variance),
sumOfMSEs, maxSingleError, testInfo, evaluatorKeys, evaluatorResults,
minComputeTimeSec, stepSize, timeFreqSec, backendType, testSetSize,
validationErrors);
}
}