blob: e90f3fa54f098056fb920b33fd08122eea970d47 [file] [log] [blame]
package com.android.nn.benchmark.evaluators;
import com.android.nn.benchmark.util.SequenceUtils;
import java.util.List;
/**
* Inference evaluator for the ASR model.
*
* This validates that the Phone Error Rate (PER) is within the limit.
*/
public class PhoneErrorRate extends BaseSequenceEvaluator {
    /** Maximum acceptable phone error rate, in percent. */
    private static final float PHONE_ERROR_RATE_LIMIT = 5f; // 5%

    /** Largest phone error rate (percent) observed across all evaluated sequences. */
    private float mMaxPER = 0f;

    /**
     * Computes the PER of one sequence, records a validation error if it exceeds
     * {@link #PHONE_ERROR_RATE_LIMIT}, and updates the running maximum PER.
     *
     * @param outputs per-step scores produced by the model
     * @param expectedOutputs per-step reference scores
     * @param outValidationErrors list that receives a message when the limit is exceeded
     */
    @Override
    protected void EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs,
            List<String> outValidationErrors) {
        float per = calculatePER(outputs, expectedOutputs);
        if (per > PHONE_ERROR_RATE_LIMIT) {
            outValidationErrors.add("Phone error rate exceeded the limit: " + per);
        }
        mMaxPER = Math.max(mMaxPER, per);
    }

    /**
     * Appends the maximum PER seen so far under the key {@code "max_phone_error_rate"}.
     *
     * @param keys list receiving the result key
     * @param values list receiving the result value, parallel to {@code keys}
     */
    @Override
    protected void AddValidationResult(List<String> keys, List<Float> values) {
        keys.add("max_phone_error_rate");
        values.add(mMaxPER);
    }

    /**
     * Calculates Phone Error Rate in percent.
     *
     * <p>Each step is reduced to a phone id (see {@link #toPhones}), and the edit distance
     * between the two phone sequences is normalized by the length of the reference
     * ({@code expectedOutputs}) sequence. The sequences may have different lengths.
     *
     * @return 0 when both sequences are empty; {@link Float#POSITIVE_INFINITY} when the
     *     reference is empty but phones were emitted (so the limit check still fails);
     *     otherwise {@code 100 * editDistance / referenceLength}
     */
    private static float calculatePER(float[][] outputs, float[][] expectedOutputs) {
        int[] outputPhones = toPhones(outputs);
        int[] expectedOutputPhones = toPhones(expectedOutputs);
        if (expectedOutputPhones.length == 0) {
            // Avoid 0/0 == NaN, which would slip past the limit check and poison mMaxPER.
            return outputPhones.length == 0 ? 0f : Float.POSITIVE_INFINITY;
        }
        int errorCount = SequenceUtils.calculateEditDistance(outputPhones, expectedOutputPhones);
        return (float) (errorCount * 100.0 / expectedOutputPhones.length);
    }

    /** Maps each step's score vector to a phone id using {@code SequenceUtils.indexOfLargest}. */
    private static int[] toPhones(float[][] scores) {
        int[] phones = new int[scores.length];
        for (int i = 0; i < scores.length; ++i) {
            phones[i] = SequenceUtils.indexOfLargest(scores[i]);
        }
        return phones;
    }
}