1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.ml.clustering;
18
19 import java.util.ArrayList;
20 import java.util.Collection;
21 import java.util.Collections;
22 import java.util.List;
23
24 import org.apache.commons.math4.legacy.exception.NullArgumentException;
25 import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
26 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
27 import org.apache.commons.math4.legacy.linear.MatrixUtils;
28 import org.apache.commons.math4.legacy.linear.RealMatrix;
29 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
30 import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
31 import org.apache.commons.rng.simple.RandomSource;
32 import org.apache.commons.rng.UniformRandomProvider;
33 import org.apache.commons.math4.core.jdkmath.JdkMath;
34 import org.apache.commons.math4.legacy.core.MathArrays;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67 public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {
68
69
70 private static final double DEFAULT_EPSILON = 1e-3;
71
72
73 private final int k;
74
75
76 private final int maxIterations;
77
78
79 private final double fuzziness;
80
81
82 private final double epsilon;
83
84
85 private final UniformRandomProvider random;
86
87
88 private double[][] membershipMatrix;
89
90
91 private List<T> points;
92
93
94 private List<CentroidCluster<T>> clusters;
95
96
97
98
99
100
101
102
103
104
105 public FuzzyKMeansClusterer(final int k, final double fuzziness) {
106 this(k, fuzziness, -1, new EuclideanDistance());
107 }
108
109
110
111
112
113
114
115
116
117
118
119 public FuzzyKMeansClusterer(final int k, final double fuzziness,
120 final int maxIterations, final DistanceMeasure measure) {
121 this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, RandomSource.MT_64.create());
122 }
123
124
125
126
127
128
129
130
131
132
133
134
135
136 public FuzzyKMeansClusterer(final int k, final double fuzziness,
137 final int maxIterations, final DistanceMeasure measure,
138 final double epsilon, final UniformRandomProvider random) {
139 super(measure);
140
141 if (fuzziness <= 1.0d) {
142 throw new NumberIsTooSmallException(fuzziness, 1.0, false);
143 }
144 this.k = k;
145 this.fuzziness = fuzziness;
146 this.maxIterations = maxIterations;
147 this.epsilon = epsilon;
148 this.random = random;
149
150 this.membershipMatrix = null;
151 this.points = null;
152 this.clusters = null;
153 }
154
155
156
157
158
159 public int getK() {
160 return k;
161 }
162
163
164
165
166
167 public double getFuzziness() {
168 return fuzziness;
169 }
170
171
172
173
174
175 public int getMaxIterations() {
176 return maxIterations;
177 }
178
179
180
181
182
183 public double getEpsilon() {
184 return epsilon;
185 }
186
187
188
189
190
191 public UniformRandomProvider getRandomGenerator() {
192 return random;
193 }
194
195
196
197
198
199
200
201
202
203
204
205 public RealMatrix getMembershipMatrix() {
206 if (membershipMatrix == null) {
207 throw new MathIllegalStateException();
208 }
209 return MatrixUtils.createRealMatrix(membershipMatrix);
210 }
211
212
213
214
215
216
217
218 public List<T> getDataPoints() {
219 return points;
220 }
221
222
223
224
225
226
227 public List<CentroidCluster<T>> getClusters() {
228 return clusters;
229 }
230
231
232
233
234
235
236 public double getObjectiveFunctionValue() {
237 if (points == null || clusters == null) {
238 throw new MathIllegalStateException();
239 }
240
241 int i = 0;
242 double objFunction = 0.0;
243 for (final T point : points) {
244 int j = 0;
245 for (final CentroidCluster<T> cluster : clusters) {
246 final double dist = distance(point, cluster.getCenter());
247 objFunction += (dist * dist) * JdkMath.pow(membershipMatrix[i][j], fuzziness);
248 j++;
249 }
250 i++;
251 }
252 return objFunction;
253 }
254
255
256
257
258
259
260
261
262
263
264 @Override
265 public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints) {
266
267 NullArgumentException.check(dataPoints);
268
269 final int size = dataPoints.size();
270
271
272 if (size < k) {
273 throw new NumberIsTooSmallException(size, k, false);
274 }
275
276
277 points = Collections.unmodifiableList(new ArrayList<>(dataPoints));
278 clusters = new ArrayList<>();
279 membershipMatrix = new double[size][k];
280 final double[][] oldMatrix = new double[size][k];
281
282
283 if (size == 0) {
284 return clusters;
285 }
286
287 initializeMembershipMatrix();
288
289
290 final int pointDimension = points.get(0).getPoint().length;
291 for (int i = 0; i < k; i++) {
292 clusters.add(new CentroidCluster<>(new DoublePoint(new double[pointDimension])));
293 }
294
295 int iteration = 0;
296 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
297 double difference = 0.0;
298
299 do {
300 saveMembershipMatrix(oldMatrix);
301 updateClusterCenters();
302 updateMembershipMatrix();
303 difference = calculateMaxMembershipChange(oldMatrix);
304 } while (difference > epsilon && ++iteration < max);
305
306 return clusters;
307 }
308
309
310
311
312 private void updateClusterCenters() {
313 int j = 0;
314 final List<CentroidCluster<T>> newClusters = new ArrayList<>(k);
315 for (final CentroidCluster<T> cluster : clusters) {
316 final Clusterable center = cluster.getCenter();
317 int i = 0;
318 double[] arr = new double[center.getPoint().length];
319 double sum = 0.0;
320 for (final T point : points) {
321 final double u = JdkMath.pow(membershipMatrix[i][j], fuzziness);
322 final double[] pointArr = point.getPoint();
323 for (int idx = 0; idx < arr.length; idx++) {
324 arr[idx] += u * pointArr[idx];
325 }
326 sum += u;
327 i++;
328 }
329 MathArrays.scaleInPlace(1.0 / sum, arr);
330 newClusters.add(new CentroidCluster<>(new DoublePoint(arr)));
331 j++;
332 }
333 clusters.clear();
334 clusters = newClusters;
335 }
336
337
338
339
340
341 private void updateMembershipMatrix() {
342 for (int i = 0; i < points.size(); i++) {
343 final T point = points.get(i);
344 double maxMembership = Double.MIN_VALUE;
345 int newCluster = -1;
346 for (int j = 0; j < clusters.size(); j++) {
347 double sum = 0.0;
348 final double distA = JdkMath.abs(distance(point, clusters.get(j).getCenter()));
349
350 if (distA != 0.0) {
351 for (final CentroidCluster<T> c : clusters) {
352 final double distB = JdkMath.abs(distance(point, c.getCenter()));
353 if (distB == 0.0) {
354 sum = Double.POSITIVE_INFINITY;
355 break;
356 }
357 sum += JdkMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
358 }
359 }
360
361 double membership;
362 if (sum == 0.0) {
363 membership = 1.0;
364 } else if (sum == Double.POSITIVE_INFINITY) {
365 membership = 0.0;
366 } else {
367 membership = 1.0 / sum;
368 }
369 membershipMatrix[i][j] = membership;
370
371 if (membershipMatrix[i][j] > maxMembership) {
372 maxMembership = membershipMatrix[i][j];
373 newCluster = j;
374 }
375 }
376 clusters.get(newCluster).addPoint(point);
377 }
378 }
379
380
381
382
383 private void initializeMembershipMatrix() {
384 for (int i = 0; i < points.size(); i++) {
385 for (int j = 0; j < k; j++) {
386 membershipMatrix[i][j] = random.nextDouble();
387 }
388 membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
389 }
390 }
391
392
393
394
395
396
397
398
399 private double calculateMaxMembershipChange(final double[][] matrix) {
400 double maxMembership = 0.0;
401 for (int i = 0; i < points.size(); i++) {
402 for (int j = 0; j < clusters.size(); j++) {
403 double v = JdkMath.abs(membershipMatrix[i][j] - matrix[i][j]);
404 maxMembership = JdkMath.max(v, maxMembership);
405 }
406 }
407 return maxMembership;
408 }
409
410
411
412
413
414
415 private void saveMembershipMatrix(final double[][] matrix) {
416 for (int i = 0; i < points.size(); i++) {
417 System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
418 }
419 }
420 }