Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature mini batch k means2 #120

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
141 changes: 141 additions & 0 deletions src/main/java/org/apache/commons/math4/ml/clustering/ClusterUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.commons.math4.ml.clustering;

import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.math4.stat.descriptive.moment.Variance;
import org.apache.commons.rng.UniformRandomProvider;

import java.util.Collection;
import java.util.List;

/**
 * Static helper functions shared by the centroid-based clustering algorithms.
 * <p>
 * The class only holds static members and cannot be instantiated.
 */
public final class ClusterUtils {
    /** Distance measure used by the overloads that do not take one: the Euclidean distance. */
    public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();

    /**
     * Utility class: not meant to be instantiated.
     */
    private ClusterUtils() {
    }

    /**
     * Finds the cluster whose center is nearest to the given point.
     * <p>
     * Clusters are visited in list order and a cluster only replaces the current
     * candidate when its distance is strictly smaller, so the first of several
     * equally distant clusters wins.
     *
     * @param clusters clusters to choose from
     * @param point point to assign to one of the clusters
     * @param measure distance measure used to compare the point with each cluster center
     * @param <T> type of the points held by the clusters
     * @return the cluster whose center is nearest to the point, or {@code null} if
     * {@code clusters} is empty or no center is at a distance strictly smaller than
     * {@link Double#POSITIVE_INFINITY} (for instance when every computed distance is NaN)
     */
    public static <T extends Clusterable> CentroidCluster<T> predict(final List<CentroidCluster<T>> clusters,
                                                                     final Clusterable point,
                                                                     final DistanceMeasure measure) {
        double minDistance = Double.POSITIVE_INFINITY;
        CentroidCluster<T> nearestCluster = null;
        for (final CentroidCluster<T> cluster : clusters) {
            final double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
            // Strict comparison: NaN and infinite distances never select a cluster.
            if (distance < minDistance) {
                minDistance = distance;
                nearestCluster = cluster;
            }
        }
        return nearestCluster;
    }

    /**
     * Finds the cluster whose center is nearest to the given point, using
     * {@link #DEFAULT_MEASURE} as the distance measure.
     *
     * @param clusters clusters to choose from
     * @param point point to assign to one of the clusters
     * @param <T> type of the points held by the clusters
     * @return the cluster whose center is nearest to the point, or {@code null} under the
     * same conditions as {@link #predict(List, Clusterable, DistanceMeasure)}
     */
    public static <T extends Clusterable> CentroidCluster<T> predict(final List<CentroidCluster<T>> clusters,
                                                                     final Clusterable point) {
        return predict(clusters, point, DEFAULT_MEASURE);
    }

    /**
     * Computes the centroid (coordinate-wise arithmetic mean) of a set of points.
     * <p>
     * Only the first {@code dimension} coordinates of each point are read; a point with
     * fewer coordinates causes an {@link ArrayIndexOutOfBoundsException}. If {@code points}
     * is empty, every coordinate of the result is NaN (zero divided by zero).
     *
     * @param points the set of points
     * @param dimension the point dimension
     * @param <T> type of the points
     * @return the computed centroid for the set of points
     */
    public static <T extends Clusterable> Clusterable centroidOf(final Collection<T> points, final int dimension) {
        final double[] centroid = new double[dimension];
        for (final T p : points) {
            final double[] point = p.getPoint();
            for (int i = 0; i < dimension; i++) {
                centroid[i] += point[i];
            }
        }
        // The size does not change while averaging, so query the collection only once.
        final int count = points.size();
        for (int i = 0; i < dimension; i++) {
            centroid[i] /= count;
        }
        return new DoublePoint(centroid);
    }

    /**
     * Removes and returns a random point from the {@link Cluster} with the largest distance variance.
     * <p>
     * For every non-empty cluster, the variance of the distances between its points and its
     * center is computed; the cluster with the strictly largest variance is selected (the first
     * one in iteration order on ties). A point is then drawn at random and <em>removed</em> from
     * the list returned by {@link Cluster#getPoints()} of the selected cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @param measure distance measure used to compute point-to-center distances
     * @param random random generator used to pick the point inside the selected cluster
     * @param <T> type of the points held by the clusters
     * @return a random point from the selected cluster
     * @throws ConvergenceException if no cluster could be selected, in particular
     * if clusters are all empty
     */
    public static <T extends Clusterable> T getPointFromLargestVarianceCluster(
            final Collection<CentroidCluster<T>> clusters,
            final DistanceMeasure measure,
            final UniformRandomProvider random)
        throws ConvergenceException {
        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final CentroidCluster<T> cluster : clusters) {
            final List<T> clusterPoints = cluster.getPoints();
            if (clusterPoints.isEmpty()) {
                // Empty clusters cannot provide a point.
                continue;
            }

            // compute the distance variance of the current cluster
            final Clusterable center = cluster.getCenter();
            final Variance stat = new Variance();
            for (final T point : clusterPoints) {
                stat.increment(measure.compute(point.getPoint(), center.getPoint()));
            }
            final double variance = stat.getResult();

            // select the cluster with the largest variance
            if (variance > maxVariance) {
                maxVariance = variance;
                selected = cluster;
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster (the point is removed from the list)
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.commons.math4.exception.MathIllegalArgumentException;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.rng.simple.RandomSource;
Expand Down Expand Up @@ -70,6 +72,9 @@ public enum EmptyClusterStrategy {
/** Selected strategy for empty clusters. */
private final EmptyClusterStrategy emptyStrategy;

/** Centroid initialization algorithm. */
private final CentroidInitializer centroidInitializer;

/** Build a clusterer.
* <p>
* The default strategy for handling empty clusters that may appear during
Expand Down Expand Up @@ -148,6 +153,8 @@ public KMeansPlusPlusClusterer(final int k, final int maxIterations,
this.maxIterations = maxIterations;
this.random = random;
this.emptyStrategy = emptyStrategy;
// KMeansPlusPlusClusterer always chooses its initial centroids with the k-means++ algorithm.
this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure,random);
}

/**
Expand Down Expand Up @@ -205,7 +212,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
}

// create the initial clusters
List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points, k);

// create an array containing the latest assignment of a point to a cluster
// no need to initialize the array, as it will be filled with the first assignment
Expand Down Expand Up @@ -235,7 +242,7 @@ public List<CentroidCluster<T>> cluster(final Collection<T> points)
}
emptyCluster = true;
} else {
newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
newCenter = ClusterUtils.centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
}
newClusters.add(new CentroidCluster<T>(newCenter));
}
Expand Down Expand Up @@ -278,131 +285,6 @@ private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
return assignedDifferently;
}

/**
* Use K-means++ to choose the initial centers.
* <p>
* The first center is drawn uniformly at random from the points. Each further
* center is drawn from the points not yet taken, with a probability proportional
* to the squared distance between the point and its nearest already chosen center.
* The loop stops once {@code k} centers have been chosen, or earlier if no
* untaken point is left.
*
* @param points the points to choose the initial centers from
* @return the initial centers (at most {@code k} of them)
*/
private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

// Convert to list for indexed access. Make it unmodifiable, since removal of items
// would invalidate the index-based bookkeeping (taken / minDistSquared) of this method.
final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points));

// The number of points in the list.
final int numPoints = pointList.size();

// Set the corresponding element in this array to indicate when
// elements of pointList are no longer available.
final boolean[] taken = new boolean[numPoints];

// The resulting list of initial centers.
final List<CentroidCluster<T>> resultSet = new ArrayList<>();

// Choose one center uniformly at random from among the data points.
final int firstPointIndex = random.nextInt(numPoints);

final T firstPoint = pointList.get(firstPointIndex);

resultSet.add(new CentroidCluster<T>(firstPoint));

// Must mark it as taken
taken[firstPointIndex] = true;

// To keep track of the minimum distance squared of elements of
// pointList to elements of resultSet.
final double[] minDistSquared = new double[numPoints];

// Initialize the elements. Since the only point in resultSet is firstPoint,
// this is very easy: the minimum is simply the squared distance to firstPoint.
for (int i = 0; i < numPoints; i++) {
if (i != firstPointIndex) { // That point isn't considered
double d = distance(firstPoint, pointList.get(i));
minDistSquared[i] = d*d;
}
}

while (resultSet.size() < k) {

// Sum up the squared distances for the points in pointList not
// already taken.
double distSqSum = 0.0;

for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
distSqSum += minDistSquared[i];
}
}

// Add one new data point as a center. Each point x is chosen with
// probability proportional to D(x)^2, where D(x) is the distance from x
// to its nearest already chosen center. r is a random threshold in the
// range of the cumulative sum of those squared distances.
final double r = random.nextDouble() * distSqSum;

// The index of the next point to be added to the resultSet.
int nextPointIndex = -1;

// Sum through the squared min distances again, stopping when
// sum >= r.
double sum = 0.0;
for (int i = 0; i < numPoints; i++) {
if (!taken[i]) {
sum += minDistSquared[i];
if (sum >= r) {
nextPointIndex = i;
break;
}
}
}

// If it is still -1, the point wasn't found in the previous
// for loop, probably because distances are extremely small. Just pick
// the last available point.
if (nextPointIndex == -1) {
for (int i = numPoints - 1; i >= 0; i--) {
if (!taken[i]) {
nextPointIndex = i;
break;
}
}
}

// We found one.
if (nextPointIndex >= 0) {

final T p = pointList.get(nextPointIndex);

resultSet.add(new CentroidCluster<T> (p));

// Mark it as taken.
taken[nextPointIndex] = true;

// The update below is skipped once the last center has been chosen,
// because minDistSquared is not read again after that.
if (resultSet.size() < k) {
// Now update elements of minDistSquared. We only have to compute
// the distance to the new center to do this.
for (int j = 0; j < numPoints; j++) {
// Only have to worry about the points still not taken.
if (!taken[j]) {
double d = distance(p, pointList.get(j));
double d2 = d * d;
if (d2 < minDistSquared[j]) {
minDistSquared[j] = d2;
}
}
}
}

} else {
// None found (every point is already taken) --
// Break from the while loop to prevent
// an infinite loop.
break;
}
}

return resultSet;
}

/**
* Get a random point from the {@link Cluster} with the largest distance variance.
*
Expand Down Expand Up @@ -540,26 +422,4 @@ private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, fin
}
return minCluster;
}

/**
 * Averages the coordinates of a set of points to obtain their centroid.
 *
 * @param points the points to average
 * @param dimension the number of coordinates of each point
 * @return a new point holding the coordinate-wise mean of {@code points}
 */
private Clusterable centroidOf(final Collection<T> points, final int dimension) {
    final double[] mean = new double[dimension];

    // First pass: accumulate the coordinate sums.
    for (final T member : points) {
        final double[] coordinates = member.getPoint();
        int axis = 0;
        while (axis < mean.length) {
            mean[axis] += coordinates[axis];
            axis++;
        }
    }

    // Second pass: turn the sums into averages.
    final int count = points.size();
    for (int axis = 0; axis < mean.length; axis++) {
        mean[axis] /= count;
    }

    return new DoublePoint(mean);
}

}
Loading