blob: 4d5b3331e7b10107264fd872f2057fd8d8c229e1 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
package com.android.internal.ml.clustering;
import android.annotation.NonNull;
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
* Simple K-Means implementation
*/
public class KMeans {
private static final boolean DEBUG = false;
private static final String TAG = "KMeans";
private final Random mRandomState;
private final int mMaxIterations;
private float mSqConvergenceEpsilon;
public KMeans() {
this(new Random());
}
public KMeans(Random random) {
this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
}
public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
mRandomState = random;
mMaxIterations = maxIterations;
mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
}
/**
* Runs k-means on the input data (X) trying to find k means.
*
* K-Means is known for getting stuck into local optima, so you might
* want to run it multiple time and argmax on {@link KMeans#score(List)}
*
* @param k The number of points to return.
* @param inputData Input data.
* @return An array of k Means, each representing a centroid and data points that belong to it.
*/
public List<Mean> predict(final int k, final float[][] inputData) {
checkDataSetSanity(inputData);
int dimension = inputData[0].length;
final ArrayList<Mean> means = new ArrayList<>();
for (int i = 0; i < k; i++) {
Mean m = new Mean(dimension);
for (int j = 0; j < dimension; j++) {
m.mCentroid[j] = mRandomState.nextFloat();
}
means.add(m);
}
// Iterate until we converge or run out of iterations
boolean converged = false;
for (int i = 0; i < mMaxIterations; i++) {
converged = step(means, inputData);
if (converged) {
if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
break;
}
}
if (!converged && DEBUG) Log.d(TAG, "Did not converge");
return means;
}
/**
* Score calculates the inertia between means.
* This can be considered as an E step of an EM algorithm.
*
* @param means Means to use when calculating score.
* @return The score
*/
public static double score(@NonNull List<Mean> means) {
double score = 0;
final int meansSize = means.size();
for (int i = 0; i < meansSize; i++) {
Mean mean = means.get(i);
for (int j = 0; j < meansSize; j++) {
Mean compareTo = means.get(j);
if (mean == compareTo) {
continue;
}
double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
score += distance;
}
}
return score;
}
@VisibleForTesting
public void checkDataSetSanity(float[][] inputData) {
if (inputData == null) {
throw new IllegalArgumentException("Data set is null.");
} else if (inputData.length == 0) {
throw new IllegalArgumentException("Data set is empty.");
} else if (inputData[0] == null) {
throw new IllegalArgumentException("Bad data set format.");
}
final int dimension = inputData[0].length;
final int length = inputData.length;
for (int i = 1; i < length; i++) {
if (inputData[i] == null || inputData[i].length != dimension) {
throw new IllegalArgumentException("Bad data set format.");
}
}
}
/**
* K-Means iteration.
*
* @param means Current means
* @param inputData Input data
* @return True if data set converged
*/
private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
// Clean up the previous state because we need to compute
// which point belongs to each mean again.
for (int i = means.size() - 1; i >= 0; i--) {
final Mean mean = means.get(i);
mean.mClosestItems.clear();
}
for (int i = inputData.length - 1; i >= 0; i--) {
final float[] current = inputData[i];
final Mean nearest = nearestMean(current, means);
nearest.mClosestItems.add(current);
}
boolean converged = true;
// Move each mean towards the nearest data set points
for (int i = means.size() - 1; i >= 0; i--) {
final Mean mean = means.get(i);
if (mean.mClosestItems.size() == 0) {
continue;
}
// Compute the new mean centroid:
// 1. Sum all all points
// 2. Average them
final float[] oldCentroid = mean.mCentroid;
mean.mCentroid = new float[oldCentroid.length];
for (int j = 0; j < mean.mClosestItems.size(); j++) {
// Update each centroid component
for (int p = 0; p < mean.mCentroid.length; p++) {
mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
}
}
for (int j = 0; j < mean.mCentroid.length; j++) {
mean.mCentroid[j] /= mean.mClosestItems.size();
}
// We converged if the centroid didn't move for any of the means.
if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
converged = false;
}
}
return converged;
}
@VisibleForTesting
public static Mean nearestMean(float[] point, List<Mean> means) {
Mean nearest = null;
float nearestDistance = Float.MAX_VALUE;
final int meanCount = means.size();
for (int i = 0; i < meanCount; i++) {
Mean next = means.get(i);
// We don't need the sqrt when comparing distances in euclidean space
// because they exist on both sides of the equation and cancel each other out.
float nextDistance = sqDistance(point, next.mCentroid);
if (nextDistance < nearestDistance) {
nearest = next;
nearestDistance = nextDistance;
}
}
return nearest;
}
@VisibleForTesting
public static float sqDistance(float[] a, float[] b) {
float dist = 0;
final int length = a.length;
for (int i = 0; i < length; i++) {
dist += (a[i] - b[i]) * (a[i] - b[i]);
}
return dist;
}
/**
* Definition of a mean, contains a centroid and points on its cluster.
*/
public static class Mean {
float[] mCentroid;
final ArrayList<float[]> mClosestItems = new ArrayList<>();
public Mean(int dimension) {
mCentroid = new float[dimension];
}
public Mean(float ...centroid) {
mCentroid = centroid;
}
public float[] getCentroid() {
return mCentroid;
}
public List<float[]> getItems() {
return mClosestItems;
}
@Override
public String toString() {
return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
+ mClosestItems.size() + ")";
}
}
}