source: src/main/java/agents/anac/y2019/harddealer/math3/ml/clustering/KMeansPlusPlusClusterer.java

Last change on this file was revision 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 21.2 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package agents.anac.y2019.harddealer.math3.ml.clustering;
19
20import java.util.ArrayList;
21import java.util.Collection;
22import java.util.Collections;
23import java.util.List;
24
25import agents.anac.y2019.harddealer.math3.exception.ConvergenceException;
26import agents.anac.y2019.harddealer.math3.exception.MathIllegalArgumentException;
27import agents.anac.y2019.harddealer.math3.exception.NumberIsTooSmallException;
28import agents.anac.y2019.harddealer.math3.exception.util.LocalizedFormats;
29import agents.anac.y2019.harddealer.math3.ml.distance.DistanceMeasure;
30import agents.anac.y2019.harddealer.math3.ml.distance.EuclideanDistance;
31import agents.anac.y2019.harddealer.math3.random.JDKRandomGenerator;
32import agents.anac.y2019.harddealer.math3.random.RandomGenerator;
33import agents.anac.y2019.harddealer.math3.stat.descriptive.moment.Variance;
34import agents.anac.y2019.harddealer.math3.util.MathUtils;
35
36/**
37 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
38 * @param <T> type of the points to cluster
39 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
40 * @since 3.2
41 */
42public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
43
44 /** Strategies to use for replacing an empty cluster. */
45 public enum EmptyClusterStrategy {
46
47 /** Split the cluster with largest distance variance. */
48 LARGEST_VARIANCE,
49
50 /** Split the cluster with largest number of points. */
51 LARGEST_POINTS_NUMBER,
52
53 /** Create a cluster around the point farthest from its centroid. */
54 FARTHEST_POINT,
55
56 /** Generate an error. */
57 ERROR
58
59 }
60
61 /** The number of clusters. */
62 private final int k;
63
64 /** The maximum number of iterations. */
65 private final int maxIterations;
66
67 /** Random generator for choosing initial centers. */
68 private final RandomGenerator random;
69
70 /** Selected strategy for empty clusters. */
71 private final EmptyClusterStrategy emptyStrategy;
72
73 /** Build a clusterer.
74 * <p>
75 * The default strategy for handling empty clusters that may appear during
76 * algorithm iterations is to split the cluster with largest distance variance.
77 * <p>
78 * The euclidean distance will be used as default distance measure.
79 *
80 * @param k the number of clusters to split the data into
81 */
82 public KMeansPlusPlusClusterer(final int k) {
83 this(k, -1);
84 }
85
86 /** Build a clusterer.
87 * <p>
88 * The default strategy for handling empty clusters that may appear during
89 * algorithm iterations is to split the cluster with largest distance variance.
90 * <p>
91 * The euclidean distance will be used as default distance measure.
92 *
93 * @param k the number of clusters to split the data into
94 * @param maxIterations the maximum number of iterations to run the algorithm for.
95 * If negative, no maximum will be used.
96 */
97 public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
98 this(k, maxIterations, new EuclideanDistance());
99 }
100
101 /** Build a clusterer.
102 * <p>
103 * The default strategy for handling empty clusters that may appear during
104 * algorithm iterations is to split the cluster with largest distance variance.
105 *
106 * @param k the number of clusters to split the data into
107 * @param maxIterations the maximum number of iterations to run the algorithm for.
108 * If negative, no maximum will be used.
109 * @param measure the distance measure to use
110 */
111 public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
112 this(k, maxIterations, measure, new JDKRandomGenerator());
113 }
114
115 /** Build a clusterer.
116 * <p>
117 * The default strategy for handling empty clusters that may appear during
118 * algorithm iterations is to split the cluster with largest distance variance.
119 *
120 * @param k the number of clusters to split the data into
121 * @param maxIterations the maximum number of iterations to run the algorithm for.
122 * If negative, no maximum will be used.
123 * @param measure the distance measure to use
124 * @param random random generator to use for choosing initial centers
125 */
126 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
127 final DistanceMeasure measure,
128 final RandomGenerator random) {
129 this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
130 }
131
132 /** Build a clusterer.
133 *
134 * @param k the number of clusters to split the data into
135 * @param maxIterations the maximum number of iterations to run the algorithm for.
136 * If negative, no maximum will be used.
137 * @param measure the distance measure to use
138 * @param random random generator to use for choosing initial centers
139 * @param emptyStrategy strategy to use for handling empty clusters that
140 * may appear during algorithm iterations
141 */
142 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
143 final DistanceMeasure measure,
144 final RandomGenerator random,
145 final EmptyClusterStrategy emptyStrategy) {
146 super(measure);
147 this.k = k;
148 this.maxIterations = maxIterations;
149 this.random = random;
150 this.emptyStrategy = emptyStrategy;
151 }
152
153 /**
154 * Return the number of clusters this instance will use.
155 * @return the number of clusters
156 */
157 public int getK() {
158 return k;
159 }
160
161 /**
162 * Returns the maximum number of iterations this instance will use.
163 * @return the maximum number of iterations, or -1 if no maximum is set
164 */
165 public int getMaxIterations() {
166 return maxIterations;
167 }
168
169 /**
170 * Returns the random generator this instance will use.
171 * @return the random generator
172 */
173 public RandomGenerator getRandomGenerator() {
174 return random;
175 }
176
177 /**
178 * Returns the {@link EmptyClusterStrategy} used by this instance.
179 * @return the {@link EmptyClusterStrategy}
180 */
181 public EmptyClusterStrategy getEmptyClusterStrategy() {
182 return emptyStrategy;
183 }
184
185 /**
186 * Runs the K-means++ clustering algorithm.
187 *
188 * @param points the points to cluster
189 * @return a list of clusters containing the points
190 * @throws MathIllegalArgumentException if the data points are null or the number
191 * of clusters is larger than the number of data points
192 * @throws ConvergenceException if an empty cluster is encountered and the
193 * {@link #emptyStrategy} is set to {@code ERROR}
194 */
195 @Override
196 public List<CentroidCluster<T>> cluster(final Collection<T> points)
197 throws MathIllegalArgumentException, ConvergenceException {
198
199 // sanity checks
200 MathUtils.checkNotNull(points);
201
202 // number of clusters has to be smaller or equal the number of data points
203 if (points.size() < k) {
204 throw new NumberIsTooSmallException(points.size(), k, false);
205 }
206
207 // create the initial clusters
208 List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
209
210 // create an array containing the latest assignment of a point to a cluster
211 // no need to initialize the array, as it will be filled with the first assignment
212 int[] assignments = new int[points.size()];
213 assignPointsToClusters(clusters, points, assignments);
214
215 // iterate through updating the centers until we're done
216 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
217 for (int count = 0; count < max; count++) {
218 boolean emptyCluster = false;
219 List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>();
220 for (final CentroidCluster<T> cluster : clusters) {
221 final Clusterable newCenter;
222 if (cluster.getPoints().isEmpty()) {
223 switch (emptyStrategy) {
224 case LARGEST_VARIANCE :
225 newCenter = getPointFromLargestVarianceCluster(clusters);
226 break;
227 case LARGEST_POINTS_NUMBER :
228 newCenter = getPointFromLargestNumberCluster(clusters);
229 break;
230 case FARTHEST_POINT :
231 newCenter = getFarthestPoint(clusters);
232 break;
233 default :
234 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
235 }
236 emptyCluster = true;
237 } else {
238 newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
239 }
240 newClusters.add(new CentroidCluster<T>(newCenter));
241 }
242 int changes = assignPointsToClusters(newClusters, points, assignments);
243 clusters = newClusters;
244
245 // if there were no more changes in the point-to-cluster assignment
246 // and there are no empty clusters left, return the current clusters
247 if (changes == 0 && !emptyCluster) {
248 return clusters;
249 }
250 }
251 return clusters;
252 }
253
254 /**
255 * Adds the given points to the closest {@link Cluster}.
256 *
257 * @param clusters the {@link Cluster}s to add the points to
258 * @param points the points to add to the given {@link Cluster}s
259 * @param assignments points assignments to clusters
260 * @return the number of points assigned to different clusters as the iteration before
261 */
262 private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
263 final Collection<T> points,
264 final int[] assignments) {
265 int assignedDifferently = 0;
266 int pointIndex = 0;
267 for (final T p : points) {
268 int clusterIndex = getNearestCluster(clusters, p);
269 if (clusterIndex != assignments[pointIndex]) {
270 assignedDifferently++;
271 }
272
273 CentroidCluster<T> cluster = clusters.get(clusterIndex);
274 cluster.addPoint(p);
275 assignments[pointIndex++] = clusterIndex;
276 }
277
278 return assignedDifferently;
279 }
280
281 /**
282 * Use K-means++ to choose the initial centers.
283 *
284 * @param points the points to choose the initial centers from
285 * @return the initial centers
286 */
287 private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
288
289 // Convert to list for indexed access. Make it unmodifiable, since removal of items
290 // would screw up the logic of this method.
291 final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
292
293 // The number of points in the list.
294 final int numPoints = pointList.size();
295
296 // Set the corresponding element in this array to indicate when
297 // elements of pointList are no longer available.
298 final boolean[] taken = new boolean[numPoints];
299
300 // The resulting list of initial centers.
301 final List<CentroidCluster<T>> resultSet = new ArrayList<CentroidCluster<T>>();
302
303 // Choose one center uniformly at random from among the data points.
304 final int firstPointIndex = random.nextInt(numPoints);
305
306 final T firstPoint = pointList.get(firstPointIndex);
307
308 resultSet.add(new CentroidCluster<T>(firstPoint));
309
310 // Must mark it as taken
311 taken[firstPointIndex] = true;
312
313 // To keep track of the minimum distance squared of elements of
314 // pointList to elements of resultSet.
315 final double[] minDistSquared = new double[numPoints];
316
317 // Initialize the elements. Since the only point in resultSet is firstPoint,
318 // this is very easy.
319 for (int i = 0; i < numPoints; i++) {
320 if (i != firstPointIndex) { // That point isn't considered
321 double d = distance(firstPoint, pointList.get(i));
322 minDistSquared[i] = d*d;
323 }
324 }
325
326 while (resultSet.size() < k) {
327
328 // Sum up the squared distances for the points in pointList not
329 // already taken.
330 double distSqSum = 0.0;
331
332 for (int i = 0; i < numPoints; i++) {
333 if (!taken[i]) {
334 distSqSum += minDistSquared[i];
335 }
336 }
337
338 // Add one new data point as a center. Each point x is chosen with
339 // probability proportional to D(x)2
340 final double r = random.nextDouble() * distSqSum;
341
342 // The index of the next point to be added to the resultSet.
343 int nextPointIndex = -1;
344
345 // Sum through the squared min distances again, stopping when
346 // sum >= r.
347 double sum = 0.0;
348 for (int i = 0; i < numPoints; i++) {
349 if (!taken[i]) {
350 sum += minDistSquared[i];
351 if (sum >= r) {
352 nextPointIndex = i;
353 break;
354 }
355 }
356 }
357
358 // If it's not set to >= 0, the point wasn't found in the previous
359 // for loop, probably because distances are extremely small. Just pick
360 // the last available point.
361 if (nextPointIndex == -1) {
362 for (int i = numPoints - 1; i >= 0; i--) {
363 if (!taken[i]) {
364 nextPointIndex = i;
365 break;
366 }
367 }
368 }
369
370 // We found one.
371 if (nextPointIndex >= 0) {
372
373 final T p = pointList.get(nextPointIndex);
374
375 resultSet.add(new CentroidCluster<T> (p));
376
377 // Mark it as taken.
378 taken[nextPointIndex] = true;
379
380 if (resultSet.size() < k) {
381 // Now update elements of minDistSquared. We only have to compute
382 // the distance to the new center to do this.
383 for (int j = 0; j < numPoints; j++) {
384 // Only have to worry about the points still not taken.
385 if (!taken[j]) {
386 double d = distance(p, pointList.get(j));
387 double d2 = d * d;
388 if (d2 < minDistSquared[j]) {
389 minDistSquared[j] = d2;
390 }
391 }
392 }
393 }
394
395 } else {
396 // None found --
397 // Break from the while loop to prevent
398 // an infinite loop.
399 break;
400 }
401 }
402
403 return resultSet;
404 }
405
406 /**
407 * Get a random point from the {@link Cluster} with the largest distance variance.
408 *
409 * @param clusters the {@link Cluster}s to search
410 * @return a random point from the selected cluster
411 * @throws ConvergenceException if clusters are all empty
412 */
413 private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
414 throws ConvergenceException {
415
416 double maxVariance = Double.NEGATIVE_INFINITY;
417 Cluster<T> selected = null;
418 for (final CentroidCluster<T> cluster : clusters) {
419 if (!cluster.getPoints().isEmpty()) {
420
421 // compute the distance variance of the current cluster
422 final Clusterable center = cluster.getCenter();
423 final Variance stat = new Variance();
424 for (final T point : cluster.getPoints()) {
425 stat.increment(distance(point, center));
426 }
427 final double variance = stat.getResult();
428
429 // select the cluster with the largest variance
430 if (variance > maxVariance) {
431 maxVariance = variance;
432 selected = cluster;
433 }
434
435 }
436 }
437
438 // did we find at least one non-empty cluster ?
439 if (selected == null) {
440 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
441 }
442
443 // extract a random point from the cluster
444 final List<T> selectedPoints = selected.getPoints();
445 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
446
447 }
448
449 /**
450 * Get a random point from the {@link Cluster} with the largest number of points
451 *
452 * @param clusters the {@link Cluster}s to search
453 * @return a random point from the selected cluster
454 * @throws ConvergenceException if clusters are all empty
455 */
456 private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
457 throws ConvergenceException {
458
459 int maxNumber = 0;
460 Cluster<T> selected = null;
461 for (final Cluster<T> cluster : clusters) {
462
463 // get the number of points of the current cluster
464 final int number = cluster.getPoints().size();
465
466 // select the cluster with the largest number of points
467 if (number > maxNumber) {
468 maxNumber = number;
469 selected = cluster;
470 }
471
472 }
473
474 // did we find at least one non-empty cluster ?
475 if (selected == null) {
476 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
477 }
478
479 // extract a random point from the cluster
480 final List<T> selectedPoints = selected.getPoints();
481 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
482
483 }
484
485 /**
486 * Get the point farthest to its cluster center
487 *
488 * @param clusters the {@link Cluster}s to search
489 * @return point farthest to its cluster center
490 * @throws ConvergenceException if clusters are all empty
491 */
492 private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws ConvergenceException {
493
494 double maxDistance = Double.NEGATIVE_INFINITY;
495 Cluster<T> selectedCluster = null;
496 int selectedPoint = -1;
497 for (final CentroidCluster<T> cluster : clusters) {
498
499 // get the farthest point
500 final Clusterable center = cluster.getCenter();
501 final List<T> points = cluster.getPoints();
502 for (int i = 0; i < points.size(); ++i) {
503 final double distance = distance(points.get(i), center);
504 if (distance > maxDistance) {
505 maxDistance = distance;
506 selectedCluster = cluster;
507 selectedPoint = i;
508 }
509 }
510
511 }
512
513 // did we find at least one non-empty cluster ?
514 if (selectedCluster == null) {
515 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
516 }
517
518 return selectedCluster.getPoints().remove(selectedPoint);
519
520 }
521
522 /**
523 * Returns the nearest {@link Cluster} to the given point
524 *
525 * @param clusters the {@link Cluster}s to search
526 * @param point the point to find the nearest {@link Cluster} for
527 * @return the index of the nearest {@link Cluster} to the given point
528 */
529 private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
530 double minDistance = Double.MAX_VALUE;
531 int clusterIndex = 0;
532 int minCluster = 0;
533 for (final CentroidCluster<T> c : clusters) {
534 final double distance = distance(point, c.getCenter());
535 if (distance < minDistance) {
536 minDistance = distance;
537 minCluster = clusterIndex;
538 }
539 clusterIndex++;
540 }
541 return minCluster;
542 }
543
544 /**
545 * Computes the centroid for a set of points.
546 *
547 * @param points the set of points
548 * @param dimension the point dimension
549 * @return the computed centroid for the set of points
550 */
551 private Clusterable centroidOf(final Collection<T> points, final int dimension) {
552 final double[] centroid = new double[dimension];
553 for (final T p : points) {
554 final double[] point = p.getPoint();
555 for (int i = 0; i < centroid.length; i++) {
556 centroid[i] += point[i];
557 }
558 }
559 for (int i = 0; i < centroid.length; i++) {
560 centroid[i] /= points.size();
561 }
562 return new DoublePoint(centroid);
563 }
564
565}
Note: See TracBrowser for help on using the repository browser.