blob: 19ed8d07394ee393320d3a6b4a02d5ae90cd1b3c [file] [log] [blame]
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.internal.graphics.palette;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.util.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
/**
 * A color quantizer based on the K-means algorithm. Prefer using QuantizerCelebi.
 *
 * <p>This is an implementation of K-means based on Celebi's 2011 paper,
 * "Improving the Performance of K-Means for Color Quantization". In the paper, this algorithm is
 * referred to as "WSMeans", or "Weighted Square Means". The main advantages of this K-means
 * implementation are that it uses the triangle inequality to skip distance calculations, and
 * that it indexes colors by their count, minimizing the number of points to move around.
 *
 * <p>Celebi's paper also stabilizes results and guarantees high quality by using starting
 * centroids from Wu's quantization algorithm. See QuantizerCelebi for more info.
 */
public final class WSMeansQuantizer implements Quantizer {
    private static final String TAG = "QuantizerWsmeans";
    private static final boolean DEBUG = false;
    private static final int MAX_ITERATIONS = 10;
    // A point is not moved to a closer cluster unless doing so reduces its distance by more
    // than this amount. 3.0 is used because L*a*b* delta E < 3 is considered imperceptible.
    private static final float MIN_MOVEMENT_DISTANCE = 3.0f;
    // Fixed seed so the random cluster choices below are reproducible between runs.
    private static final long RANDOM_SEED = 0x42688;

    private final PointProvider mPointProvider;
    // Distinct color -> number of occurrences; a color's count is its weight when averaging.
    private @Nullable Map<Integer, Integer> mInputPixelToCount;
    // Current cluster centers, in the point space of mPointProvider.
    private float[][] mClusters;
    // Total weight (pixel count) assigned to each cluster, indexed by cluster.
    private int[] mClusterPopulations;
    // mPoints[i] is mPixels[i] converted by mPointProvider.
    private float[][] mPoints;
    private int[] mPixels;
    // mClusterIndices[i] is the cluster that point i is currently assigned to.
    private int[] mClusterIndices;
    // Row i lists cluster indices sorted by their distance from cluster i.
    private int[][] mIndexMatrix = {};
    // Symmetric matrix of distances between cluster centers.
    private float[][] mDistanceMatrix = {};
    private Palette mPalette;

    /**
     * Creates a quantizer.
     *
     * @param inClusters starting cluster centers as color ints (e.g. the output of Wu's
     *     quantizer). If non-empty, at most this many colors are produced. If empty, starting
     *     centers are picked at random from the input colors.
     * @param pointProvider converts between color ints and the points that are clustered
     * @param inputPixelToCount optional precomputed map from color to occurrence count; when
     *     null it is computed from the pixels passed to {@link #quantize}
     */
    public WSMeansQuantizer(int[] inClusters, PointProvider pointProvider,
            @Nullable Map<Integer, Integer> inputPixelToCount) {
        mPointProvider = pointProvider;
        mClusters = new float[inClusters.length][3];
        int index = 0;
        for (int cluster : inClusters) {
            float[] point = pointProvider.fromInt(cluster);
            mClusters[index++] = point;
        }
        mInputPixelToCount = inputPixelToCount;
    }

    /** Returns the swatches produced by the last call to {@link #quantize}. */
    @Override
    public List<Palette.Swatch> getQuantizedColors() {
        return mPalette.getSwatches();
    }

    /**
     * Clusters the colors of {@code pixels} into at most {@code maxColors} swatches.
     *
     * <p>The number of swatches is further limited by the number of starting clusters (if any
     * were given) and by the number of distinct input colors. A non-positive limit, or no
     * distinct colors, produces an empty palette.
     */
    @Override
    public void quantize(@NonNull int[] pixels, int maxColors) {
        assert (pixels.length > 0);
        if (mInputPixelToCount == null) {
            QuantizerMap mapQuantizer = new QuantizerMap();
            mapQuantizer.quantize(pixels, maxColors);
            mInputPixelToCount = mapQuantizer.getColorToCount();
        }
        // One point per distinct color; occurrence counts are applied as weights later.
        mPoints = new float[mInputPixelToCount.size()][];
        mPixels = new int[mInputPixelToCount.size()];
        int index = 0;
        for (int pixel : mInputPixelToCount.keySet()) {
            mPixels[index] = pixel;
            mPoints[index] = mPointProvider.fromInt(pixel);
            index++;
        }
        if (mClusters.length > 0) {
            // This implies that the constructor was provided starting clusters. If that was the
            // case, we limit the number of clusters to the number of starting clusters and don't
            // initialize random clusters.
            maxColors = Math.min(maxColors, mClusters.length);
        }
        maxColors = Math.min(maxColors, mPoints.length);

        List<Palette.Swatch> swatches = new ArrayList<>();
        // With no clusters to form, skip K-means entirely (Random.nextInt(0) would throw).
        if (maxColors > 0) {
            initializeClusters(maxColors);
            for (int i = 0; i < MAX_ITERATIONS; i++) {
                calculateClusterDistances(maxColors);
                if (!reassignPoints(maxColors)) {
                    break;
                }
                recalculateClusterCenters(maxColors);
            }
            for (int i = 0; i < maxColors; i++) {
                int colorInt = mPointProvider.toInt(mClusters[i]);
                swatches.add(new Palette.Swatch(colorInt, mClusterPopulations[i]));
            }
        }
        mPalette = Palette.from(swatches);
    }

    /**
     * Ensures there are {@code maxColors} starting centers, then assigns every point to a
     * random cluster and records the resulting per-cluster populations.
     *
     * @param maxColors number of clusters; must be positive and at most {@code mPoints.length}
     */
    private void initializeClusters(int maxColors) {
        if (mClusters.length == 0) {
            if (DEBUG) {
                Log.d(TAG, "have " + mClusters.length + " clusters, want " + maxColors
                        + " results, so need " + maxColors + " additional clusters");
            }
            mClusters = pickRandomStartingClusters(maxColors);
        }
        mClusterIndices = new int[mPixels.length];
        // Populations are indexed by cluster (not by point) so they stay consistent with
        // mClusterIndices even if the first reassignment pass moves nothing.
        mClusterPopulations = new int[maxColors];
        Random random = new Random(RANDOM_SEED);
        for (int i = 0; i < mPixels.length; i++) {
            int clusterIndex = random.nextInt(maxColors);
            mClusterIndices[i] = clusterIndex;
            mClusterPopulations[clusterIndex] += mInputPixelToCount.get(mPixels[i]);
        }
    }

    /**
     * Picks {@code count} distinct input points to serve as starting cluster centers.
     *
     * @param count number of centers; must not exceed {@code mPoints.length}
     * @return new center arrays, each a copy of the chosen point
     */
    private float[][] pickRandomStartingClusters(int count) {
        Random random = new Random(RANDOM_SEED);
        float[][] clusters = new float[count][];
        Set<Integer> clusterIndicesUsed = new HashSet<>();
        for (int i = 0; i < count; i++) {
            int index = random.nextInt(mPoints.length);
            while (clusterIndicesUsed.contains(index)
                    && clusterIndicesUsed.size() < mPoints.length) {
                index = random.nextInt(mPoints.length);
            }
            clusterIndicesUsed.add(index);
            // Copy: recalculateClusterCenters updates centers in place, which would otherwise
            // overwrite the input point the center was seeded from.
            clusters[i] = mPoints[index].clone();
        }
        return clusters;
    }

    /**
     * Fills {@code mDistanceMatrix} with the pairwise distances between the first
     * {@code maxColors} cluster centers, and fills {@code mIndexMatrix} so that row i lists
     * cluster indices in ascending order of distance from cluster i.
     */
    void calculateClusterDistances(int maxColors) {
        if (mDistanceMatrix.length != maxColors) {
            mDistanceMatrix = new float[maxColors][maxColors];
        }
        for (int i = 0; i < maxColors; i++) {
            for (int j = i + 1; j < maxColors; j++) {
                float distance = mPointProvider.distance(mClusters[i], mClusters[j]);
                mDistanceMatrix[j][i] = distance;
                mDistanceMatrix[i][j] = distance;
            }
        }
        if (mIndexMatrix.length != maxColors) {
            mIndexMatrix = new int[maxColors][maxColors];
        }
        for (int i = 0; i < maxColors; i++) {
            ArrayList<Distance> distances = new ArrayList<>(maxColors);
            for (int index = 0; index < maxColors; index++) {
                distances.add(new Distance(index, mDistanceMatrix[i][index]));
            }
            distances.sort(
                    (a, b) -> Float.compare(a.getDistance(), b.getDistance()));
            for (int j = 0; j < maxColors; j++) {
                mIndexMatrix[i][j] = distances.get(j).getIndex();
            }
        }
    }

    /**
     * Moves each point to its nearest cluster, but only when that improves its distance by more
     * than {@link #MIN_MOVEMENT_DISTANCE}.
     *
     * @return true if any point changed cluster
     */
    boolean reassignPoints(int maxColors) {
        boolean colorMoved = false;
        for (int i = 0; i < mPoints.length; i++) {
            float[] point = mPoints[i];
            int previousClusterIndex = mClusterIndices[i];
            float[] previousCluster = mClusters[previousClusterIndex];
            float previousDistance = mPointProvider.distance(point, previousCluster);
            float minimumDistance = previousDistance;
            int newClusterIndex = -1;
            // j starts at 1: entry 0 of the sorted row is normally the cluster itself.
            for (int j = 1; j < maxColors; j++) {
                int t = mIndexMatrix[previousClusterIndex][j];
                // Assuming distance() is squared (as the sqrt below implies): if the centers
                // are at least 2x the point's distance apart, the triangle inequality proves no
                // closer center exists. The row is sorted, so all later candidates fail too.
                if (mDistanceMatrix[previousClusterIndex][t] >= 4 * previousDistance) {
                    break;
                }
                float distance = mPointProvider.distance(point, mClusters[t]);
                if (distance < minimumDistance) {
                    minimumDistance = distance;
                    newClusterIndex = t;
                }
            }
            if (newClusterIndex != -1) {
                float distanceChange = (float)
                        Math.abs((Math.sqrt(minimumDistance) - Math.sqrt(previousDistance)));
                if (distanceChange > MIN_MOVEMENT_DISTANCE) {
                    colorMoved = true;
                    mClusterIndices[i] = newClusterIndex;
                }
            }
        }
        return colorMoved;
    }

    /**
     * Recomputes each cluster center as the count-weighted mean of its assigned points, and
     * recomputes {@code mClusterPopulations}. A cluster with no assigned points keeps its
     * previous center.
     */
    void recalculateClusterCenters(int maxColors) {
        mClusterPopulations = new int[maxColors];
        // Accumulate in double: weighted sums over many pixels lose precision in float.
        double[] aSums = new double[maxColors];
        double[] bSums = new double[maxColors];
        double[] cSums = new double[maxColors];
        for (int i = 0; i < mPoints.length; i++) {
            int clusterIndex = mClusterIndices[i];
            float[] point = mPoints[i];
            int count = mInputPixelToCount.get(mPixels[i]);
            mClusterPopulations[clusterIndex] += count;
            aSums[clusterIndex] += (double) point[0] * count;
            bSums[clusterIndex] += (double) point[1] * count;
            cSums[clusterIndex] += (double) point[2] * count;
        }
        for (int i = 0; i < maxColors; i++) {
            int count = mClusterPopulations[i];
            if (count == 0) {
                // Empty cluster: dividing by zero would make the center NaN.
                continue;
            }
            mClusters[i][0] = (float) (aSums[i] / count);
            mClusters[i][1] = (float) (bSums[i] / count);
            mClusters[i][2] = (float) (cSums[i] / count);
        }
    }

    /** A cluster index paired with its distance from some reference cluster, for sorting. */
    private static final class Distance {
        private final int mIndex;
        private final float mDistance;

        int getIndex() {
            return mIndex;
        }

        float getDistance() {
            return mDistance;
        }

        Distance(int index, float distance) {
            mIndex = index;
            mDistance = distance;
        }
    }
}