blob: 8c82c09aeba2c3b786c34bc590111fdaa6cb9e03 [file] [log] [blame]
package com.android.nn.benchmark.evaluators;
import com.android.nn.benchmark.core.EvaluatorInterface;
import com.android.nn.benchmark.core.InferenceInOut;
import com.android.nn.benchmark.core.InferenceInOutSequence;
import com.android.nn.benchmark.core.InferenceResult;
import com.android.nn.benchmark.core.OutputMeanStdDev;
import com.android.nn.benchmark.util.IOUtils;
import java.util.List;
/**
 * Base class for evaluators that score accuracy one input/output sequence at a time.
 *
 * <p>{@link #EvaluateAccuracy} splits the flat list of inference results into sequences. For every
 * step it decodes the actual and expected values of model output {@link #targetOutputIndex},
 * optionally denormalizes them, and passes each complete sequence to
 * {@link #EvaluateSequenceAccuracy}. After all results are consumed, {@link #AddValidationResult}
 * is called once so the subclass can publish its aggregated metrics.
 */
public abstract class BaseSequenceEvaluator implements EvaluatorInterface {
    /**
     * When non-null, both actual and expected outputs are passed through
     * {@code denormalize} before comparison; when null the decoded values are used as-is.
     */
    private OutputMeanStdDev mOutputMeanStdDev = null;

    /** Index into the per-inference output arrays of the output being evaluated (default 0). */
    protected int targetOutputIndex = 0;

    /**
     * Sets the mean/std-dev used to denormalize actual and expected outputs before they are
     * handed to {@link #EvaluateSequenceAccuracy}. Pass {@code null} to disable denormalization.
     */
    public void setOutputMeanStdDev(OutputMeanStdDev outputMeanStdDev) {
        mOutputMeanStdDev = outputMeanStdDev;
    }

    /**
     * Evaluates all inference results sequence by sequence.
     *
     * <p>Results are assumed to be ordered so that they cycle through {@code inferenceInOuts}
     * (sequence 0, 1, ..., wrapping around), each sequence contributing exactly
     * {@code sequence.size()} consecutive results. Expected values for each result are looked up
     * via the result's own {@code mInputOutputSequenceIndex}/{@code mInputOutputIndex}.
     *
     * @param inferenceInOuts input/output sequences containing the expected outputs
     * @param inferenceResults actual inference results, in the order described above
     * @param outKeys receives metric names via {@link #AddValidationResult}
     * @param outValues receives metric values via {@link #AddValidationResult}
     * @param outValidationErrors passed to {@link #EvaluateSequenceAccuracy} for error reporting
     * @throws IllegalArgumentException if {@code inferenceInOuts} or its first sequence is empty,
     *     or if {@code inferenceResults} ends part-way through a sequence
     */
    @Override
    public void EvaluateAccuracy(
            List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults,
            List<String> outKeys, List<Float> outValues,
            List<String> outValidationErrors) {
        // The first sequence's first entry is used below to derive the element size and count,
        // so it must exist.
        if (inferenceInOuts.isEmpty() || inferenceInOuts.get(0).size() == 0) {
            throw new IllegalArgumentException("Empty inputs/outputs");
        }
        int dataSize = inferenceInOuts.get(0).mDatasize;
        // Number of decoded values per step, derived from the expected output's byte length.
        int outputSize = inferenceInOuts.get(0).get(0).mExpectedOutputs[targetOutputIndex].length
                / dataSize;

        int sequenceIndex = 0;
        int inferenceIndex = 0;
        while (inferenceIndex < inferenceResults.size()) {
            int sequenceLength =
                    inferenceInOuts.get(sequenceIndex % inferenceInOuts.size()).size();
            int remaining = inferenceResults.size() - inferenceIndex;
            if (remaining < sequenceLength) {
                // Fail with a clear message instead of an IndexOutOfBoundsException half-way
                // through decoding a truncated sequence.
                throw new IllegalArgumentException(
                        "Inference results end mid-sequence: sequence " + sequenceIndex
                                + " needs " + sequenceLength + " results starting at index "
                                + inferenceIndex + " but only " + remaining + " remain");
            }

            float[][] outputs = new float[sequenceLength][outputSize];
            float[][] expectedOutputs = new float[sequenceLength][outputSize];
            for (int i = 0; i < sequenceLength; ++i, ++inferenceIndex) {
                InferenceResult result = inferenceResults.get(inferenceIndex);
                float[] actual = denormalizeIfConfigured(
                        IOUtils.readFloats(result.mInferenceOutput[targetOutputIndex], dataSize));
                System.arraycopy(actual, 0, outputs[i], 0, outputSize);

                InferenceInOut inOut = inferenceInOuts.get(result.mInputOutputSequenceIndex)
                        .get(result.mInputOutputIndex);
                float[] expected = denormalizeIfConfigured(
                        IOUtils.readFloats(inOut.mExpectedOutputs[targetOutputIndex], dataSize));
                System.arraycopy(expected, 0, expectedOutputs[i], 0, outputSize);
            }
            EvaluateSequenceAccuracy(outputs, expectedOutputs, outValidationErrors);
            ++sequenceIndex;
        }
        AddValidationResult(outKeys, outValues);
    }

    /** Returns {@code values} denormalized when a mean/std-dev is configured, else unchanged. */
    private float[] denormalizeIfConfigured(float[] values) {
        return mOutputMeanStdDev == null ? values : mOutputMeanStdDev.denormalize(values);
    }

    /**
     * Scores one complete sequence.
     *
     * @param outputs actual values, one row per step of the sequence
     * @param expectedOutputs expected values, row-aligned with {@code outputs}
     * @param outValidationErrors list the implementation may append validation messages to
     */
    protected abstract void EvaluateSequenceAccuracy(float[][] outputs, float[][] expectedOutputs,
            List<String> outValidationErrors);

    /**
     * Called once after all sequences have been evaluated.
     *
     * @param keys list the implementation may append metric names to
     * @param values list the implementation may append metric values to, presumably
     *     index-aligned with {@code keys} (TODO confirm against implementations)
     */
    protected abstract void AddValidationResult(List<String> keys, List<Float> values);
}