source: src/main/java/agents/anac/y2019/harddealer/math3/stat/clustering/KMeansPlusPlusClusterer.java

Last change on this file was 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 19.4 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package agents.anac.y2019.harddealer.math3.stat.clustering;
19
20import java.util.ArrayList;
21import java.util.Collection;
22import java.util.Collections;
23import java.util.List;
24import java.util.Random;
25
26import agents.anac.y2019.harddealer.math3.exception.ConvergenceException;
27import agents.anac.y2019.harddealer.math3.exception.MathIllegalArgumentException;
28import agents.anac.y2019.harddealer.math3.exception.NumberIsTooSmallException;
29import agents.anac.y2019.harddealer.math3.exception.util.LocalizedFormats;
30import agents.anac.y2019.harddealer.math3.stat.descriptive.moment.Variance;
31import agents.anac.y2019.harddealer.math3.util.MathUtils;
32
33/**
34 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
35 * @param <T> type of the points to cluster
36 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
37 * @since 2.0
38 * @deprecated As of 3.2 (to be removed in 4.0),
39 * use {@link agents.anac.y2019.harddealer.math3.ml.clustering.KMeansPlusPlusClusterer} instead
40 */
41@Deprecated
42public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
43
44 /** Strategies to use for replacing an empty cluster. */
45 public enum EmptyClusterStrategy {
46
47 /** Split the cluster with largest distance variance. */
48 LARGEST_VARIANCE,
49
50 /** Split the cluster with largest number of points. */
51 LARGEST_POINTS_NUMBER,
52
53 /** Create a cluster around the point farthest from its centroid. */
54 FARTHEST_POINT,
55
56 /** Generate an error. */
57 ERROR
58
59 }
60
61 /** Random generator for choosing initial centers. */
62 private final Random random;
63
64 /** Selected strategy for empty clusters. */
65 private final EmptyClusterStrategy emptyStrategy;
66
67 /** Build a clusterer.
68 * <p>
69 * The default strategy for handling empty clusters that may appear during
70 * algorithm iterations is to split the cluster with largest distance variance.
71 * </p>
72 * @param random random generator to use for choosing initial centers
73 */
74 public KMeansPlusPlusClusterer(final Random random) {
75 this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
76 }
77
78 /** Build a clusterer.
79 * @param random random generator to use for choosing initial centers
80 * @param emptyStrategy strategy to use for handling empty clusters that
81 * may appear during algorithm iterations
82 * @since 2.2
83 */
84 public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
85 this.random = random;
86 this.emptyStrategy = emptyStrategy;
87 }
88
89 /**
90 * Runs the K-means++ clustering algorithm.
91 *
92 * @param points the points to cluster
93 * @param k the number of clusters to split the data into
94 * @param numTrials number of trial runs
95 * @param maxIterationsPerTrial the maximum number of iterations to run the algorithm
96 * for at each trial run. If negative, no maximum will be used
97 * @return a list of clusters containing the points
98 * @throws MathIllegalArgumentException if the data points are null or the number
99 * of clusters is larger than the number of data points
100 * @throws ConvergenceException if an empty cluster is encountered and the
101 * {@link #emptyStrategy} is set to {@code ERROR}
102 */
103 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
104 int numTrials, int maxIterationsPerTrial)
105 throws MathIllegalArgumentException, ConvergenceException {
106
107 // at first, we have not found any clusters list yet
108 List<Cluster<T>> best = null;
109 double bestVarianceSum = Double.POSITIVE_INFINITY;
110
111 // do several clustering trials
112 for (int i = 0; i < numTrials; ++i) {
113
114 // compute a clusters list
115 List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial);
116
117 // compute the variance of the current list
118 double varianceSum = 0.0;
119 for (final Cluster<T> cluster : clusters) {
120 if (!cluster.getPoints().isEmpty()) {
121
122 // compute the distance variance of the current cluster
123 final T center = cluster.getCenter();
124 final Variance stat = new Variance();
125 for (final T point : cluster.getPoints()) {
126 stat.increment(point.distanceFrom(center));
127 }
128 varianceSum += stat.getResult();
129
130 }
131 }
132
133 if (varianceSum <= bestVarianceSum) {
134 // this one is the best we have found so far, remember it
135 best = clusters;
136 bestVarianceSum = varianceSum;
137 }
138
139 }
140
141 // return the best clusters list found
142 return best;
143
144 }
145
146 /**
147 * Runs the K-means++ clustering algorithm.
148 *
149 * @param points the points to cluster
150 * @param k the number of clusters to split the data into
151 * @param maxIterations the maximum number of iterations to run the algorithm
152 * for. If negative, no maximum will be used
153 * @return a list of clusters containing the points
154 * @throws MathIllegalArgumentException if the data points are null or the number
155 * of clusters is larger than the number of data points
156 * @throws ConvergenceException if an empty cluster is encountered and the
157 * {@link #emptyStrategy} is set to {@code ERROR}
158 */
159 public List<Cluster<T>> cluster(final Collection<T> points, final int k,
160 final int maxIterations)
161 throws MathIllegalArgumentException, ConvergenceException {
162
163 // sanity checks
164 MathUtils.checkNotNull(points);
165
166 // number of clusters has to be smaller or equal the number of data points
167 if (points.size() < k) {
168 throw new NumberIsTooSmallException(points.size(), k, false);
169 }
170
171 // create the initial clusters
172 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
173
174 // create an array containing the latest assignment of a point to a cluster
175 // no need to initialize the array, as it will be filled with the first assignment
176 int[] assignments = new int[points.size()];
177 assignPointsToClusters(clusters, points, assignments);
178
179 // iterate through updating the centers until we're done
180 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
181 for (int count = 0; count < max; count++) {
182 boolean emptyCluster = false;
183 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
184 for (final Cluster<T> cluster : clusters) {
185 final T newCenter;
186 if (cluster.getPoints().isEmpty()) {
187 switch (emptyStrategy) {
188 case LARGEST_VARIANCE :
189 newCenter = getPointFromLargestVarianceCluster(clusters);
190 break;
191 case LARGEST_POINTS_NUMBER :
192 newCenter = getPointFromLargestNumberCluster(clusters);
193 break;
194 case FARTHEST_POINT :
195 newCenter = getFarthestPoint(clusters);
196 break;
197 default :
198 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
199 }
200 emptyCluster = true;
201 } else {
202 newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
203 }
204 newClusters.add(new Cluster<T>(newCenter));
205 }
206 int changes = assignPointsToClusters(newClusters, points, assignments);
207 clusters = newClusters;
208
209 // if there were no more changes in the point-to-cluster assignment
210 // and there are no empty clusters left, return the current clusters
211 if (changes == 0 && !emptyCluster) {
212 return clusters;
213 }
214 }
215 return clusters;
216 }
217
218 /**
219 * Adds the given points to the closest {@link Cluster}.
220 *
221 * @param <T> type of the points to cluster
222 * @param clusters the {@link Cluster}s to add the points to
223 * @param points the points to add to the given {@link Cluster}s
224 * @param assignments points assignments to clusters
225 * @return the number of points assigned to different clusters as the iteration before
226 */
227 private static <T extends Clusterable<T>> int
228 assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points,
229 final int[] assignments) {
230 int assignedDifferently = 0;
231 int pointIndex = 0;
232 for (final T p : points) {
233 int clusterIndex = getNearestCluster(clusters, p);
234 if (clusterIndex != assignments[pointIndex]) {
235 assignedDifferently++;
236 }
237
238 Cluster<T> cluster = clusters.get(clusterIndex);
239 cluster.addPoint(p);
240 assignments[pointIndex++] = clusterIndex;
241 }
242
243 return assignedDifferently;
244 }
245
246 /**
247 * Use K-means++ to choose the initial centers.
248 *
249 * @param <T> type of the points to cluster
250 * @param points the points to choose the initial centers from
251 * @param k the number of centers to choose
252 * @param random random generator to use
253 * @return the initial centers
254 */
255 private static <T extends Clusterable<T>> List<Cluster<T>>
256 chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
257
258 // Convert to list for indexed access. Make it unmodifiable, since removal of items
259 // would screw up the logic of this method.
260 final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
261
262 // The number of points in the list.
263 final int numPoints = pointList.size();
264
265 // Set the corresponding element in this array to indicate when
266 // elements of pointList are no longer available.
267 final boolean[] taken = new boolean[numPoints];
268
269 // The resulting list of initial centers.
270 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
271
272 // Choose one center uniformly at random from among the data points.
273 final int firstPointIndex = random.nextInt(numPoints);
274
275 final T firstPoint = pointList.get(firstPointIndex);
276
277 resultSet.add(new Cluster<T>(firstPoint));
278
279 // Must mark it as taken
280 taken[firstPointIndex] = true;
281
282 // To keep track of the minimum distance squared of elements of
283 // pointList to elements of resultSet.
284 final double[] minDistSquared = new double[numPoints];
285
286 // Initialize the elements. Since the only point in resultSet is firstPoint,
287 // this is very easy.
288 for (int i = 0; i < numPoints; i++) {
289 if (i != firstPointIndex) { // That point isn't considered
290 double d = firstPoint.distanceFrom(pointList.get(i));
291 minDistSquared[i] = d*d;
292 }
293 }
294
295 while (resultSet.size() < k) {
296
297 // Sum up the squared distances for the points in pointList not
298 // already taken.
299 double distSqSum = 0.0;
300
301 for (int i = 0; i < numPoints; i++) {
302 if (!taken[i]) {
303 distSqSum += minDistSquared[i];
304 }
305 }
306
307 // Add one new data point as a center. Each point x is chosen with
308 // probability proportional to D(x)2
309 final double r = random.nextDouble() * distSqSum;
310
311 // The index of the next point to be added to the resultSet.
312 int nextPointIndex = -1;
313
314 // Sum through the squared min distances again, stopping when
315 // sum >= r.
316 double sum = 0.0;
317 for (int i = 0; i < numPoints; i++) {
318 if (!taken[i]) {
319 sum += minDistSquared[i];
320 if (sum >= r) {
321 nextPointIndex = i;
322 break;
323 }
324 }
325 }
326
327 // If it's not set to >= 0, the point wasn't found in the previous
328 // for loop, probably because distances are extremely small. Just pick
329 // the last available point.
330 if (nextPointIndex == -1) {
331 for (int i = numPoints - 1; i >= 0; i--) {
332 if (!taken[i]) {
333 nextPointIndex = i;
334 break;
335 }
336 }
337 }
338
339 // We found one.
340 if (nextPointIndex >= 0) {
341
342 final T p = pointList.get(nextPointIndex);
343
344 resultSet.add(new Cluster<T> (p));
345
346 // Mark it as taken.
347 taken[nextPointIndex] = true;
348
349 if (resultSet.size() < k) {
350 // Now update elements of minDistSquared. We only have to compute
351 // the distance to the new center to do this.
352 for (int j = 0; j < numPoints; j++) {
353 // Only have to worry about the points still not taken.
354 if (!taken[j]) {
355 double d = p.distanceFrom(pointList.get(j));
356 double d2 = d * d;
357 if (d2 < minDistSquared[j]) {
358 minDistSquared[j] = d2;
359 }
360 }
361 }
362 }
363
364 } else {
365 // None found --
366 // Break from the while loop to prevent
367 // an infinite loop.
368 break;
369 }
370 }
371
372 return resultSet;
373 }
374
375 /**
376 * Get a random point from the {@link Cluster} with the largest distance variance.
377 *
378 * @param clusters the {@link Cluster}s to search
379 * @return a random point from the selected cluster
380 * @throws ConvergenceException if clusters are all empty
381 */
382 private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters)
383 throws ConvergenceException {
384
385 double maxVariance = Double.NEGATIVE_INFINITY;
386 Cluster<T> selected = null;
387 for (final Cluster<T> cluster : clusters) {
388 if (!cluster.getPoints().isEmpty()) {
389
390 // compute the distance variance of the current cluster
391 final T center = cluster.getCenter();
392 final Variance stat = new Variance();
393 for (final T point : cluster.getPoints()) {
394 stat.increment(point.distanceFrom(center));
395 }
396 final double variance = stat.getResult();
397
398 // select the cluster with the largest variance
399 if (variance > maxVariance) {
400 maxVariance = variance;
401 selected = cluster;
402 }
403
404 }
405 }
406
407 // did we find at least one non-empty cluster ?
408 if (selected == null) {
409 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
410 }
411
412 // extract a random point from the cluster
413 final List<T> selectedPoints = selected.getPoints();
414 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
415
416 }
417
418 /**
419 * Get a random point from the {@link Cluster} with the largest number of points
420 *
421 * @param clusters the {@link Cluster}s to search
422 * @return a random point from the selected cluster
423 * @throws ConvergenceException if clusters are all empty
424 */
425 private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) throws ConvergenceException {
426
427 int maxNumber = 0;
428 Cluster<T> selected = null;
429 for (final Cluster<T> cluster : clusters) {
430
431 // get the number of points of the current cluster
432 final int number = cluster.getPoints().size();
433
434 // select the cluster with the largest number of points
435 if (number > maxNumber) {
436 maxNumber = number;
437 selected = cluster;
438 }
439
440 }
441
442 // did we find at least one non-empty cluster ?
443 if (selected == null) {
444 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
445 }
446
447 // extract a random point from the cluster
448 final List<T> selectedPoints = selected.getPoints();
449 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
450
451 }
452
453 /**
454 * Get the point farthest to its cluster center
455 *
456 * @param clusters the {@link Cluster}s to search
457 * @return point farthest to its cluster center
458 * @throws ConvergenceException if clusters are all empty
459 */
460 private T getFarthestPoint(final Collection<Cluster<T>> clusters) throws ConvergenceException {
461
462 double maxDistance = Double.NEGATIVE_INFINITY;
463 Cluster<T> selectedCluster = null;
464 int selectedPoint = -1;
465 for (final Cluster<T> cluster : clusters) {
466
467 // get the farthest point
468 final T center = cluster.getCenter();
469 final List<T> points = cluster.getPoints();
470 for (int i = 0; i < points.size(); ++i) {
471 final double distance = points.get(i).distanceFrom(center);
472 if (distance > maxDistance) {
473 maxDistance = distance;
474 selectedCluster = cluster;
475 selectedPoint = i;
476 }
477 }
478
479 }
480
481 // did we find at least one non-empty cluster ?
482 if (selectedCluster == null) {
483 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
484 }
485
486 return selectedCluster.getPoints().remove(selectedPoint);
487
488 }
489
490 /**
491 * Returns the nearest {@link Cluster} to the given point
492 *
493 * @param <T> type of the points to cluster
494 * @param clusters the {@link Cluster}s to search
495 * @param point the point to find the nearest {@link Cluster} for
496 * @return the index of the nearest {@link Cluster} to the given point
497 */
498 private static <T extends Clusterable<T>> int
499 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
500 double minDistance = Double.MAX_VALUE;
501 int clusterIndex = 0;
502 int minCluster = 0;
503 for (final Cluster<T> c : clusters) {
504 final double distance = point.distanceFrom(c.getCenter());
505 if (distance < minDistance) {
506 minDistance = distance;
507 minCluster = clusterIndex;
508 }
509 clusterIndex++;
510 }
511 return minCluster;
512 }
513
514}
Note: See TracBrowser for help on using the repository browser.