blob: ce6a80413bd829c507b8edb82b2f8b984e59daa7 [file] [log] [blame]
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.nn.benchmark.core;
import android.os.Bundle;
import android.os.Parcel;
import android.os.Parcelable;
public class LatencyResult implements Parcelable {
private final static int TIME_FREQ_ARRAY_SIZE = 32;
private float mTotalTimeSec;
private int mIterations;
private float mTimeStdDeviation;
/** Time offset for inference frequency counts */
private float mTimeFreqStartSec;
/** Index time offset for inference frequency counts */
private float mTimeFreqStepSec;
/**
* Array of inference frequency counts.
* Each entry contains inference count for time range:
* [mTimeFreqStartSec + i*mTimeFreqStepSec, mTimeFreqStartSec + (1+i*mTimeFreqStepSec)
*/
private float[] mTimeFreqSec = {};
public LatencyResult(float[] results) {
mIterations = results.length;
mTotalTimeSec = 0.0f;
float maxComputeTimeSec = 0.0f;
float minComputeTimeSec = Float.MAX_VALUE;
for (float result : results) {
mTotalTimeSec += result;
maxComputeTimeSec = Math.max(maxComputeTimeSec, result);
minComputeTimeSec = Math.min(minComputeTimeSec, result);
}
// Calculate standard deviation.
float latencyMean = (mTotalTimeSec / mIterations);
float variance = 0.0f;
for (float result : results) {
float v = (result - latencyMean);
variance += v * v;
}
variance /= mIterations;
mTimeStdDeviation = (float) Math.sqrt(variance);
// Calculate inference frequency/histogram across TIME_FREQ_ARRAY_SIZE buckets.
mTimeFreqStartSec = minComputeTimeSec;
mTimeFreqStepSec = (maxComputeTimeSec - minComputeTimeSec) / (TIME_FREQ_ARRAY_SIZE - 1);
mTimeFreqSec = new float[TIME_FREQ_ARRAY_SIZE];
for (float result : results) {
int bucketIndex = (int) ((result - minComputeTimeSec) / mTimeFreqStepSec);
mTimeFreqSec[bucketIndex] += 1;
}
}
public LatencyResult(Parcel in) {
mTotalTimeSec = in.readFloat();
mIterations = in.readInt();
mTimeStdDeviation = in.readFloat();
mTimeFreqStartSec = in.readFloat();
mTimeFreqStepSec = in.readFloat();
int timeFreqSecLength = in.readInt();
mTimeFreqSec = new float[timeFreqSecLength];
in.readFloatArray(mTimeFreqSec);
}
@Override
public int describeContents() {
return 0;
}
@Override
public void writeToParcel(Parcel dest, int flags) {
dest.writeFloat(mTotalTimeSec);
dest.writeInt(mIterations);
dest.writeFloat(mTimeStdDeviation);
dest.writeFloat(mTimeFreqStartSec);
dest.writeFloat(mTimeFreqStepSec);
dest.writeInt(mTimeFreqSec.length);
dest.writeFloatArray(mTimeFreqSec);
}
public void putToBundle(Bundle results, String prefix) {
// Reported in ms
results.putFloat(prefix + "_avg", getMeanTimeSec() * 1000.0f);
results.putFloat(prefix + "_std_dev", mTimeStdDeviation * 1000.0f);
results.putFloat(prefix + "_total_time", mTotalTimeSec * 1000.0f);
results.putInt(prefix + "_iterations", mIterations);
}
@Override
public String toString() {
return "LatencyResult{"
+ "getMeanTimeSec()=" + getMeanTimeSec()
+ ", mTotalTimeSec=" + mTotalTimeSec
+ ", mIterations=" + mIterations
+ ", mTimeStdDeviation=" + mTimeStdDeviation
+ ", mTimeFreqStartSec=" + mTimeFreqStartSec
+ ", mTimeFreqStepSec=" + mTimeFreqStepSec + "}";
}
public int getIterations() { return mIterations; }
public float getMeanTimeSec() { return mTotalTimeSec / mIterations; }
private float rebase(float v, float baselineSec) {
if (v > 0.001) {
v = baselineSec / v;
}
return v;
}
public String getSummary(float baselineSec) {
java.text.DecimalFormat df = new java.text.DecimalFormat("######.##");
return df.format(rebase(getMeanTimeSec(), baselineSec)) + "X, n=" + mIterations
+ ", μ=" + df.format(getMeanTimeSec() * 1000.0)
+ "ms, σ=" + df.format(mTimeStdDeviation * 1000.0) + "ms";
}
public void appendToCsvLine(StringBuilder sb) {
sb.append(',').append(String.join(",",
String.valueOf(mIterations),
String.valueOf(mTotalTimeSec),
String.valueOf(mTimeFreqStartSec),
String.valueOf(mTimeFreqStepSec),
String.valueOf(mTimeFreqSec.length)));
for (float value : mTimeFreqSec) {
sb.append(',').append(value);
}
}
}