blob: 53c0227d033f2b621c80624bc4338dc74af6a920 [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.media.videoquality.bdrate;
import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.math3.analysis.interpolation.AkimaSplineInterpolator;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.logging.Logger;
/**
* Calculator for the Bjontegaard-Delta rate between two rate-distortion curves for an arbitrary
* metric.
*
* <p>Bjontegaard's metric makes it possible to compute the average gain in PSNR or the average
* percent saving in bitrate between two rate-distortion curves [1]. This class implements the
* Bjontegaard metric to calculate the average bitrate saving.
*
* <p>References:
*
* <ol>
*   <li>G. Bjontegaard, Calculation of average PSNR differences between RD-curves (VCEG-M33)
*   <li>S. Pateux, J. Jung, An excel add-in for computing Bjontegaard metric and its evolution
*   <li>VCEG-M34. http://wftp3.itu.int/av-arch/video-site/0104_Aus/VCEG-M34.xls
* </ol>
*/
public class BdRateCalculator {
    private static final Logger LOGGER = Logger.getLogger(BdRateCalculator.class.getName());

    /** Shared mean evaluator; only the array-based {@code evaluate} overload is used. */
    private static final Mean MEAN = new Mean();

    /**
     * Minimum number of points each (clustered) curve must contain. The Akima spline
     * interpolation used for curve fitting needs at least five data points.
     */
    private static final int MIN_CURVE_POINTS = 5;

    /**
     * A point whose rate divided by the previous point's rate is above this ratio starts a new
     * cluster; otherwise it joins the previous point's cluster.
     */
    private static final double CLUSTER_RATE_RATIO_THRESHOLD = 1.01;

    private BdRateCalculator() {}

    /** Creates a new calculator instance. */
    public static BdRateCalculator create() {
        return new BdRateCalculator();
    }

    /**
     * Calculates the Bjontegaard-Delta (BD) rate for the two provided rate-distortion curves.
     *
     * <p>Both curves are first clustered (see {@link #cluster}); all precondition checks, the
     * spline fits and the integration are then performed on the clustered curves. The log10
     * bitrate of each curve is fitted as a function of distortion, both fits are integrated over
     * the distortion range the curves share, and the average difference {@code d} is converted
     * back to a ratio as {@code 10^d - 1}. A positive result therefore means the target curve
     * uses more bitrate than the reference for the same distortion value.
     *
     * @param referenceCurve the curve used as the baseline.
     * @param targetCurve the curve being compared against the baseline.
     * @return The Bjontegaard-Delta rate value as a fraction (e.g. {@code 0.1} for +10%), or
     *     Double.NaN if it could not be calculated (for example when the two curves share only a
     *     single distortion value, making the integration range zero-width).
     * @throws BdRateCalculationFailedPreconditionException if either clustered curve has fewer
     *     than five points, is not strictly increasing in distortion, or the distortion ranges
     *     of the two curves do not overlap.
     * @throws IllegalArgumentException if any of the input data is invalid in rate-distortion
     *     context (e.g. bitrate < 0). NOTE(review): nothing in this class throws this directly;
     *     presumably it originates from the curve/point types -- confirm.
     */
    public double calculate(RateDistortionCurve referenceCurve, RateDistortionCurve targetCurve) {
        RateDistortionCurve clusteredReferenceCurve = cluster(referenceCurve);
        RateDistortionCurve clusteredTargetCurve = cluster(targetCurve);

        LOGGER.fine("Checking BD-RATE preconditions.");
        if (clusteredReferenceCurve.points().size() < MIN_CURVE_POINTS
                || clusteredTargetCurve.points().size() < MIN_CURVE_POINTS) {
            throw new BdRateCalculationFailedPreconditionException(
                    "Not enough data points in the supplied rate-distortion curves.");
        }
        if (!isMonotonicallyIncreasing(clusteredReferenceCurve)
                || !isMonotonicallyIncreasing(clusteredTargetCurve)) {
            throw new BdRateCalculationFailedPreconditionException(
                    "The supplied rate-distortion curves were not monotonically increasing.");
        }

        // Use the clustered curves from here on: they are the ones validated above, and the raw
        // curves may still contain the noisy points that clustering is meant to remove.
        CalculationParameters referenceCalcParams =
                curveToCalculationParameters(clusteredReferenceCurve);
        CalculationParameters targetCalcParams =
                curveToCalculationParameters(clusteredTargetCurve);

        if (referenceCalcParams.mMaxDistortion < targetCalcParams.mMinDistortion
                || targetCalcParams.mMaxDistortion < referenceCalcParams.mMinDistortion) {
            throw new BdRateCalculationFailedPreconditionException(
                    "The supplied rate-distortion curves do not overlap.");
        }
        LOGGER.fine("Preconditions passed, calculating BD-RATE.");

        // Fit log10(bitrate) as a function of distortion for each curve.
        AkimaSplineInterpolator akimaInterpolator = new AkimaSplineInterpolator();
        PolynomialSplineFunction referenceFitCurve =
                akimaInterpolator.interpolate(
                        referenceCalcParams.mDistortions, referenceCalcParams.mLogBitrates);
        PolynomialSplineFunction targetFitCurve =
                akimaInterpolator.interpolate(
                        targetCalcParams.mDistortions, targetCalcParams.mLogBitrates);

        // Only the distortion range covered by both curves can be compared.
        double integrationRangeMin =
                Math.max(referenceCalcParams.mMinDistortion, targetCalcParams.mMinDistortion);
        double integrationRangeMax =
                Math.min(referenceCalcParams.mMaxDistortion, targetCalcParams.mMaxDistortion);

        double referenceAuc =
                calculateAuc(referenceFitCurve, integrationRangeMin, integrationRangeMax);
        double targetAuc = calculateAuc(targetFitCurve, integrationRangeMin, integrationRangeMax);

        // Average log10 bitrate difference over the shared range, converted back to a ratio.
        double bdRateLog = (targetAuc - referenceAuc) / (integrationRangeMax - integrationRangeMin);
        return Math.pow(10, bdRateLog) - 1;
    }

    /**
     * Calculates the area under the curve for the provided {@link PolynomialSplineFunction} between
     * the min and max values.
     *
     * <p>The spline is integrated segment by segment. Segments entirely outside {@code [min, max]}
     * contribute nothing; segments that straddle a bound are integrated only over the part that
     * lies inside the range.
     */
    private static double calculateAuc(PolynomialSplineFunction func, double min, double max) {
        PolynomialFunction[] segmentFuncs = func.getPolynomials();
        double[] knots = func.getKnots();

        double result = 0;
        for (int segmentIdx = 0; segmentIdx < knots.length - 1; segmentIdx++) {
            double leftKnot = knots[segmentIdx];
            double rightKnot = knots[segmentIdx + 1];
            if (rightKnot < min) {
                // Segment ends before the integration range starts.
                continue;
            }
            if (leftKnot > max) {
                // Knots are increasing, so every later segment is out of range as well.
                break;
            }

            // Integration bounds are offsets from the segment's left knot, which is the argument
            // the spline's per-segment polynomials are evaluated with.
            double integrationLeft = Math.max(0, min - leftKnot);
            double integrationRight = Math.min(rightKnot - leftKnot, max - leftKnot);

            PolynomialFunction integralFunc = integratePolynomial(segmentFuncs[segmentIdx]);
            result += integralFunc.value(integrationRight) - integralFunc.value(integrationLeft);
        }
        return result;
    }

    /**
     * Integrates the provided {@link PolynomialFunction} term by term using the power rule,
     * returning a new {@link PolynomialFunction} representing the antiderivative with a zero
     * constant of integration.
     */
    private static PolynomialFunction integratePolynomial(PolynomialFunction function) {
        // Fetch the coefficients once instead of on every loop iteration.
        double[] coeffs = function.getCoefficients();
        // Index 0 (the integration constant) is left at its default value of 0.
        double[] integratedCoeffs = new double[coeffs.length + 1];
        for (int power = 1; power <= coeffs.length; power++) {
            // c * x^(power-1) integrates to (c / power) * x^power.
            integratedCoeffs[power] = coeffs[power - 1] / power;
        }
        return new PolynomialFunction(integratedCoeffs);
    }

    /**
     * Clusters provided rate-distortion points together to reduce noise when the points are close
     * together in terms of bitrate.
     *
     * <p>"Clusters" are points that have a bitrate that is within 1% of the previous
     * rate-distortion point. Such points are bucketed and then averaged to provide a single point
     * in the same range as the cluster. A point that is alone in its bucket is carried over
     * unchanged. Curves with fewer than three points are returned as-is.
     */
    @VisibleForTesting
    static RateDistortionCurve cluster(RateDistortionCurve baseCurve) {
        if (baseCurve.points().size() < 3) {
            return baseCurve;
        }

        // Bucket the items, moving through the points pairwise. The first point always opens the
        // first bucket.
        LinkedList<ArrayList<RateDistortionPoint>> buckets = new LinkedList<>();
        Iterator<RateDistortionPoint> pointIterator = baseCurve.points().iterator();
        RateDistortionPoint lastPoint = pointIterator.next();
        ArrayList<RateDistortionPoint> currentBucket = new ArrayList<>();
        currentBucket.add(lastPoint);
        buckets.add(currentBucket);

        while (pointIterator.hasNext()) {
            RateDistortionPoint currentPoint = pointIterator.next();
            if (currentPoint.rate() / lastPoint.rate() > CLUSTER_RATE_RATIO_THRESHOLD) {
                // Too far from the previous point: start a new bucket.
                currentBucket = new ArrayList<>();
                buckets.add(currentBucket);
            }
            currentBucket.add(currentPoint);
            lastPoint = currentPoint;
        }

        RateDistortionCurve.Builder newCurve = RateDistortionCurve.builder();
        for (ArrayList<RateDistortionPoint> bucket : buckets) {
            if (bucket.size() < 2) {
                // A lone point needs no averaging; skip to the next bucket so that it is not
                // added a second time below.
                newCurve.addPoint(bucket.get(0));
                continue;
            }

            // For a bucket with multiple points, the new point is the average of all points in
            // the bucket.
            double meanRate =
                    MEAN.evaluate(
                            bucket.stream().mapToDouble(RateDistortionPoint::rate).toArray());
            double meanDistortion =
                    MEAN.evaluate(
                            bucket.stream()
                                    .mapToDouble(RateDistortionPoint::distortion)
                                    .toArray());
            newCurve.addPoint(RateDistortionPoint.create(meanRate, meanDistortion));
        }
        return newCurve.build();
    }

    /**
     * Returns whether a {@link RateDistortionCurve} is monotonically increasing which is required
     * for the spline interpolation performed during BD rate calculation.
     *
     * <p>The check is strict: two consecutive points with equal distortion make the curve
     * non-monotonic. A curve with no points or a single point is considered monotonic.
     */
    private static boolean isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve) {
        Iterator<RateDistortionPoint> pointIterator = rateDistortionCurve.points().iterator();
        if (!pointIterator.hasNext()) {
            return true;
        }

        RateDistortionPoint lastPoint = pointIterator.next();
        while (pointIterator.hasNext()) {
            RateDistortionPoint currentPoint = pointIterator.next();
            if (currentPoint.distortion() <= lastPoint.distortion()) {
                return false;
            }
            lastPoint = currentPoint;
        }
        return true;
    }

    /**
     * Extracts the points in a {@link RateDistortionCurve} into {@link CalculationParameters} which
     * is a format friendlier for calculation.
     *
     * <p>The curve must contain at least one point; callers in this class only invoke this after
     * the minimum-size precondition has been checked.
     */
    private static CalculationParameters curveToCalculationParameters(
            RateDistortionCurve rateDistortionCurve) {
        int pointCount = rateDistortionCurve.points().size();
        CalculationParameters params = new CalculationParameters();
        params.mLogBitrates = new double[pointCount];
        params.mDistortions = new double[pointCount];

        int i = 0;
        for (RateDistortionPoint p : rateDistortionCurve.points()) {
            params.mLogBitrates[i] = Math.log10(p.rate());
            params.mDistortions[i] = p.distortion();
            i++;
        }

        // Since the values are guaranteed sorted in a rate-distortion curve,
        // min/max is just the ends of the data.
        params.mMinDistortion = params.mDistortions[0];
        params.mMaxDistortion = params.mDistortions[pointCount - 1];
        return params;
    }

    /** Internal-only dataclass for storing the parameters needed for calculating BD rate. */
    private static class CalculationParameters {
        /** log10 of each point's rate, in curve iteration order. */
        private double[] mLogBitrates;

        /** Each point's distortion, in curve iteration order (parallel to mLogBitrates). */
        private double[] mDistortions;

        /** Distortion of the first point in the curve. */
        private double mMinDistortion;

        /** Distortion of the last point in the curve. */
        private double mMaxDistortion;
    }
}