source: src/main/java/agents/anac/y2019/harddealer/math3/ml/clustering/FuzzyKMeansClusterer.java

Last change on this file was 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 15.2 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package agents.anac.y2019.harddealer.math3.ml.clustering;
18
19import java.util.ArrayList;
20import java.util.Collection;
21import java.util.Collections;
22import java.util.List;
23
24import agents.anac.y2019.harddealer.math3.exception.MathIllegalArgumentException;
25import agents.anac.y2019.harddealer.math3.exception.MathIllegalStateException;
26import agents.anac.y2019.harddealer.math3.exception.NumberIsTooSmallException;
27import agents.anac.y2019.harddealer.math3.linear.MatrixUtils;
28import agents.anac.y2019.harddealer.math3.linear.RealMatrix;
29import agents.anac.y2019.harddealer.math3.ml.distance.DistanceMeasure;
30import agents.anac.y2019.harddealer.math3.ml.distance.EuclideanDistance;
31import agents.anac.y2019.harddealer.math3.random.JDKRandomGenerator;
32import agents.anac.y2019.harddealer.math3.random.RandomGenerator;
33import agents.anac.y2019.harddealer.math3.util.FastMath;
34import agents.anac.y2019.harddealer.math3.util.MathArrays;
35import agents.anac.y2019.harddealer.math3.util.MathUtils;
36
37/**
38 * Fuzzy K-Means clustering algorithm.
39 * <p>
40 * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the
41 * major difference that a single data point is not uniquely assigned to a single cluster.
42 * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership
43 * to the cluster j.
44 * <p>
45 * The algorithm then tries to minimize the objective function:
46 * <pre>
47 * J = &#8721;<sub>i=1..C</sub>&#8721;<sub>k=1..N</sub> u<sub>ik</sub><sup>m</sup>d<sub>ik</sub><sup>2</sup>
48 * </pre>
49 * with d<sub>ik</sub> being the distance between data point i and the cluster center k.
50 * <p>
51 * The algorithm requires two parameters:
52 * <ul>
53 * <li>k: the number of clusters
54 * <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters
55 * </ul>
56 * Additional, optional parameters:
57 * <ul>
58 * <li>maxIterations: the maximum number of iterations
59 * <li>epsilon: the convergence criteria, default is 1e-3
60 * </ul>
61 * <p>
62 * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection
63 * of the initial cluster centers.
64 *
65 * @param <T> type of the points to cluster
66 * @since 3.3
67 */
68public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {
69
70 /** The default value for the convergence criteria. */
71 private static final double DEFAULT_EPSILON = 1e-3;
72
73 /** The number of clusters. */
74 private final int k;
75
76 /** The maximum number of iterations. */
77 private final int maxIterations;
78
79 /** The fuzziness factor. */
80 private final double fuzziness;
81
82 /** The convergence criteria. */
83 private final double epsilon;
84
85 /** Random generator for choosing initial centers. */
86 private final RandomGenerator random;
87
88 /** The membership matrix. */
89 private double[][] membershipMatrix;
90
91 /** The list of points used in the last call to {@link #cluster(Collection)}. */
92 private List<T> points;
93
94 /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */
95 private List<CentroidCluster<T>> clusters;
96
97 /**
98 * Creates a new instance of a FuzzyKMeansClusterer.
99 * <p>
100 * The euclidean distance will be used as default distance measure.
101 *
102 * @param k the number of clusters to split the data into
103 * @param fuzziness the fuzziness factor, must be &gt; 1.0
104 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
105 */
106 public FuzzyKMeansClusterer(final int k, final double fuzziness) throws NumberIsTooSmallException {
107 this(k, fuzziness, -1, new EuclideanDistance());
108 }
109
110 /**
111 * Creates a new instance of a FuzzyKMeansClusterer.
112 *
113 * @param k the number of clusters to split the data into
114 * @param fuzziness the fuzziness factor, must be &gt; 1.0
115 * @param maxIterations the maximum number of iterations to run the algorithm for.
116 * If negative, no maximum will be used.
117 * @param measure the distance measure to use
118 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
119 */
120 public FuzzyKMeansClusterer(final int k, final double fuzziness,
121 final int maxIterations, final DistanceMeasure measure)
122 throws NumberIsTooSmallException {
123 this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator());
124 }
125
126 /**
127 * Creates a new instance of a FuzzyKMeansClusterer.
128 *
129 * @param k the number of clusters to split the data into
130 * @param fuzziness the fuzziness factor, must be &gt; 1.0
131 * @param maxIterations the maximum number of iterations to run the algorithm for.
132 * If negative, no maximum will be used.
133 * @param measure the distance measure to use
134 * @param epsilon the convergence criteria (default is 1e-3)
135 * @param random random generator to use for choosing initial centers
136 * @throws NumberIsTooSmallException if {@code fuzziness <= 1.0}
137 */
138 public FuzzyKMeansClusterer(final int k, final double fuzziness,
139 final int maxIterations, final DistanceMeasure measure,
140 final double epsilon, final RandomGenerator random)
141 throws NumberIsTooSmallException {
142
143 super(measure);
144
145 if (fuzziness <= 1.0d) {
146 throw new NumberIsTooSmallException(fuzziness, 1.0, false);
147 }
148 this.k = k;
149 this.fuzziness = fuzziness;
150 this.maxIterations = maxIterations;
151 this.epsilon = epsilon;
152 this.random = random;
153
154 this.membershipMatrix = null;
155 this.points = null;
156 this.clusters = null;
157 }
158
159 /**
160 * Return the number of clusters this instance will use.
161 * @return the number of clusters
162 */
163 public int getK() {
164 return k;
165 }
166
167 /**
168 * Returns the fuzziness factor used by this instance.
169 * @return the fuzziness factor
170 */
171 public double getFuzziness() {
172 return fuzziness;
173 }
174
175 /**
176 * Returns the maximum number of iterations this instance will use.
177 * @return the maximum number of iterations, or -1 if no maximum is set
178 */
179 public int getMaxIterations() {
180 return maxIterations;
181 }
182
183 /**
184 * Returns the convergence criteria used by this instance.
185 * @return the convergence criteria
186 */
187 public double getEpsilon() {
188 return epsilon;
189 }
190
191 /**
192 * Returns the random generator this instance will use.
193 * @return the random generator
194 */
195 public RandomGenerator getRandomGenerator() {
196 return random;
197 }
198
199 /**
200 * Returns the {@code nxk} membership matrix, where {@code n} is the number
201 * of data points and {@code k} the number of clusters.
202 * <p>
203 * The element U<sub>i,j</sub> represents the membership value for data point {@code i}
204 * to cluster {@code j}.
205 *
206 * @return the membership matrix
207 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
208 */
209 public RealMatrix getMembershipMatrix() {
210 if (membershipMatrix == null) {
211 throw new MathIllegalStateException();
212 }
213 return MatrixUtils.createRealMatrix(membershipMatrix);
214 }
215
216 /**
217 * Returns an unmodifiable list of the data points used in the last
218 * call to {@link #cluster(Collection)}.
219 * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has
220 * not been called before.
221 */
222 public List<T> getDataPoints() {
223 return points;
224 }
225
226 /**
227 * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}.
228 * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has
229 * not been called before.
230 */
231 public List<CentroidCluster<T>> getClusters() {
232 return clusters;
233 }
234
235 /**
236 * Get the value of the objective function.
237 * @return the objective function evaluation as double value
238 * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
239 */
240 public double getObjectiveFunctionValue() {
241 if (points == null || clusters == null) {
242 throw new MathIllegalStateException();
243 }
244
245 int i = 0;
246 double objFunction = 0.0;
247 for (final T point : points) {
248 int j = 0;
249 for (final CentroidCluster<T> cluster : clusters) {
250 final double dist = distance(point, cluster.getCenter());
251 objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness);
252 j++;
253 }
254 i++;
255 }
256 return objFunction;
257 }
258
259 /**
260 * Performs Fuzzy K-Means cluster analysis.
261 *
262 * @param dataPoints the points to cluster
263 * @return the list of clusters
264 * @throws MathIllegalArgumentException if the data points are null or the number
265 * of clusters is larger than the number of data points
266 */
267 @Override
268 public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints)
269 throws MathIllegalArgumentException {
270
271 // sanity checks
272 MathUtils.checkNotNull(dataPoints);
273
274 final int size = dataPoints.size();
275
276 // number of clusters has to be smaller or equal the number of data points
277 if (size < k) {
278 throw new NumberIsTooSmallException(size, k, false);
279 }
280
281 // copy the input collection to an unmodifiable list with indexed access
282 points = Collections.unmodifiableList(new ArrayList<T>(dataPoints));
283 clusters = new ArrayList<CentroidCluster<T>>();
284 membershipMatrix = new double[size][k];
285 final double[][] oldMatrix = new double[size][k];
286
287 // if no points are provided, return an empty list of clusters
288 if (size == 0) {
289 return clusters;
290 }
291
292 initializeMembershipMatrix();
293
294 // there is at least one point
295 final int pointDimension = points.get(0).getPoint().length;
296 for (int i = 0; i < k; i++) {
297 clusters.add(new CentroidCluster<T>(new DoublePoint(new double[pointDimension])));
298 }
299
300 int iteration = 0;
301 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
302 double difference = 0.0;
303
304 do {
305 saveMembershipMatrix(oldMatrix);
306 updateClusterCenters();
307 updateMembershipMatrix();
308 difference = calculateMaxMembershipChange(oldMatrix);
309 } while (difference > epsilon && ++iteration < max);
310
311 return clusters;
312 }
313
314 /**
315 * Update the cluster centers.
316 */
317 private void updateClusterCenters() {
318 int j = 0;
319 final List<CentroidCluster<T>> newClusters = new ArrayList<CentroidCluster<T>>(k);
320 for (final CentroidCluster<T> cluster : clusters) {
321 final Clusterable center = cluster.getCenter();
322 int i = 0;
323 double[] arr = new double[center.getPoint().length];
324 double sum = 0.0;
325 for (final T point : points) {
326 final double u = FastMath.pow(membershipMatrix[i][j], fuzziness);
327 final double[] pointArr = point.getPoint();
328 for (int idx = 0; idx < arr.length; idx++) {
329 arr[idx] += u * pointArr[idx];
330 }
331 sum += u;
332 i++;
333 }
334 MathArrays.scaleInPlace(1.0 / sum, arr);
335 newClusters.add(new CentroidCluster<T>(new DoublePoint(arr)));
336 j++;
337 }
338 clusters.clear();
339 clusters = newClusters;
340 }
341
342 /**
343 * Updates the membership matrix and assigns the points to the cluster with
344 * the highest membership.
345 */
346 private void updateMembershipMatrix() {
347 for (int i = 0; i < points.size(); i++) {
348 final T point = points.get(i);
349 double maxMembership = Double.MIN_VALUE;
350 int newCluster = -1;
351 for (int j = 0; j < clusters.size(); j++) {
352 double sum = 0.0;
353 final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter()));
354
355 if (distA != 0.0) {
356 for (final CentroidCluster<T> c : clusters) {
357 final double distB = FastMath.abs(distance(point, c.getCenter()));
358 if (distB == 0.0) {
359 sum = Double.POSITIVE_INFINITY;
360 break;
361 }
362 sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
363 }
364 }
365
366 double membership;
367 if (sum == 0.0) {
368 membership = 1.0;
369 } else if (sum == Double.POSITIVE_INFINITY) {
370 membership = 0.0;
371 } else {
372 membership = 1.0 / sum;
373 }
374 membershipMatrix[i][j] = membership;
375
376 if (membershipMatrix[i][j] > maxMembership) {
377 maxMembership = membershipMatrix[i][j];
378 newCluster = j;
379 }
380 }
381 clusters.get(newCluster).addPoint(point);
382 }
383 }
384
385 /**
386 * Initialize the membership matrix with random values.
387 */
388 private void initializeMembershipMatrix() {
389 for (int i = 0; i < points.size(); i++) {
390 for (int j = 0; j < k; j++) {
391 membershipMatrix[i][j] = random.nextDouble();
392 }
393 membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
394 }
395 }
396
397 /**
398 * Calculate the maximum element-by-element change of the membership matrix
399 * for the current iteration.
400 *
401 * @param matrix the membership matrix of the previous iteration
402 * @return the maximum membership matrix change
403 */
404 private double calculateMaxMembershipChange(final double[][] matrix) {
405 double maxMembership = 0.0;
406 for (int i = 0; i < points.size(); i++) {
407 for (int j = 0; j < clusters.size(); j++) {
408 double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]);
409 maxMembership = FastMath.max(v, maxMembership);
410 }
411 }
412 return maxMembership;
413 }
414
415 /**
416 * Copy the membership matrix into the provided matrix.
417 *
418 * @param matrix the place to store the membership matrix
419 */
420 private void saveMembershipMatrix(final double[][] matrix) {
421 for (int i = 0; i < points.size(); i++) {
422 System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
423 }
424 }
425
426}
Note: See TracBrowser for help on using the repository browser.