1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math4.legacy.distribution.fitting;
18
19 import java.util.ArrayList;
20 import java.util.Arrays;
21 import java.util.List;
22
23 import org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution;
24 import org.apache.commons.math4.legacy.distribution.MultivariateNormalDistribution;
25 import org.apache.commons.math4.legacy.exception.ConvergenceException;
26 import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
27 import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
28 import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
29 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
30 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
31 import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
32 import org.apache.commons.math4.legacy.linear.RealMatrix;
33 import org.apache.commons.math4.legacy.linear.SingularMatrixException;
34 import org.apache.commons.math4.legacy.stat.correlation.Covariance;
35 import org.apache.commons.math4.core.jdkmath.JdkMath;
36 import org.apache.commons.math4.legacy.core.MathArrays;
37 import org.apache.commons.math4.legacy.core.Pair;
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 public class MultivariateNormalMixtureExpectationMaximization {
55
56
57
/** Iteration limit used by {@code fit(initialMixture)} when the caller does not supply one. */
private static final int DEFAULT_MAX_ITERATIONS = 1000;

/** Convergence threshold used by {@code fit(initialMixture)} when the caller does not supply one. */
private static final double DEFAULT_THRESHOLD = 1E-5;

/** Private copy of the data to fit, one observation per row (copied row by row in the constructor). */
private final double[][] data;

/** Mixture produced by the most recent call to {@code fit}; stays {@code null} until then. */
private MixtureMultivariateNormalDistribution fittedModel;

/** Mean per-row log-likelihood of the data, as computed by the most recent call to {@code fit}. */
private double logLikelihood;
75
76
77
78
79
80
81
82
83
84
85
86 public MultivariateNormalMixtureExpectationMaximization(double[][] data)
87 throws NotStrictlyPositiveException,
88 DimensionMismatchException,
89 NumberIsTooSmallException {
90 if (data.length < 1) {
91 throw new NotStrictlyPositiveException(data.length);
92 }
93
94 this.data = new double[data.length][data[0].length];
95
96 for (int i = 0; i < data.length; i++) {
97 if (data[i].length != data[0].length) {
98
99 throw new DimensionMismatchException(data[i].length,
100 data[0].length);
101 }
102 if (data[i].length < 1) {
103 throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL,
104 data[i].length, 1, true);
105 }
106 this.data[i] = Arrays.copyOf(data[i], data[i].length);
107 }
108 }
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132 public void fit(final MixtureMultivariateNormalDistribution initialMixture,
133 final int maxIterations,
134 final double threshold)
135 throws SingularMatrixException,
136 NotStrictlyPositiveException,
137 DimensionMismatchException {
138 if (maxIterations < 1) {
139 throw new NotStrictlyPositiveException(maxIterations);
140 }
141
142 if (threshold < Double.MIN_VALUE) {
143 throw new NotStrictlyPositiveException(threshold);
144 }
145
146 final int n = data.length;
147
148
149
150 final int numCols = data[0].length;
151 final int k = initialMixture.getComponents().size();
152
153 final int numMeanColumns
154 = initialMixture.getComponents().get(0).getSecond().getMeans().length;
155
156 if (numMeanColumns != numCols) {
157 throw new DimensionMismatchException(numMeanColumns, numCols);
158 }
159
160 int numIterations = 0;
161 double previousLogLikelihood = 0d;
162
163 logLikelihood = Double.NEGATIVE_INFINITY;
164
165
166 fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());
167
168 while (numIterations++ <= maxIterations &&
169 JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
170 previousLogLikelihood = logLikelihood;
171 double sumLogLikelihood = 0d;
172
173
174 final List<Pair<Double, MultivariateNormalDistribution>> components
175 = fittedModel.getComponents();
176
177
178 final double[] weights = new double[k];
179
180 final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];
181
182 for (int j = 0; j < k; j++) {
183 weights[j] = components.get(j).getFirst();
184 mvns[j] = components.get(j).getSecond();
185 }
186
187
188
189
190
191 final double[][] gamma = new double[n][k];
192
193
194 final double[] gammaSums = new double[k];
195
196
197 final double[][] gammaDataProdSums = new double[k][numCols];
198
199 for (int i = 0; i < n; i++) {
200 final double rowDensity = fittedModel.density(data[i]);
201 sumLogLikelihood += JdkMath.log(rowDensity);
202
203 for (int j = 0; j < k; j++) {
204 gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
205 gammaSums[j] += gamma[i][j];
206
207 for (int col = 0; col < numCols; col++) {
208 gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
209 }
210 }
211 }
212
213 logLikelihood = sumLogLikelihood / n;
214
215
216
217 final double[] newWeights = new double[k];
218 final double[][] newMeans = new double[k][numCols];
219
220 for (int j = 0; j < k; j++) {
221 newWeights[j] = gammaSums[j] / n;
222 for (int col = 0; col < numCols; col++) {
223 newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
224 }
225 }
226
227
228 final RealMatrix[] newCovMats = new RealMatrix[k];
229 for (int j = 0; j < k; j++) {
230 newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
231 }
232 for (int i = 0; i < n; i++) {
233 for (int j = 0; j < k; j++) {
234 final RealMatrix vec
235 = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
236 final RealMatrix dataCov
237 = vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
238 newCovMats[j] = newCovMats[j].add(dataCov);
239 }
240 }
241
242
243 final double[][][] newCovMatArrays = new double[k][numCols][numCols];
244 for (int j = 0; j < k; j++) {
245 newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
246 newCovMatArrays[j] = newCovMats[j].getData();
247 }
248
249
250 fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
251 newMeans,
252 newCovMatArrays);
253 }
254
255 if (JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
256
257 throw new ConvergenceException();
258 }
259 }
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278 public void fit(MixtureMultivariateNormalDistribution initialMixture)
279 throws SingularMatrixException,
280 NotStrictlyPositiveException {
281 fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
282 }
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302 public static MixtureMultivariateNormalDistribution estimate(final double[][] data,
303 final int numComponents)
304 throws NotStrictlyPositiveException,
305 DimensionMismatchException {
306 if (data.length < 2) {
307 throw new NotStrictlyPositiveException(data.length);
308 }
309 if (numComponents < 1) {
310 throw new NumberIsTooSmallException(numComponents, 1, true);
311 }
312 if (numComponents > data.length) {
313 throw new NumberIsTooLargeException(numComponents, data.length, true);
314 }
315
316 final int numRows = data.length;
317 final int numCols = data[0].length;
318
319
320 final DataRow[] sortedData = new DataRow[numRows];
321 for (int i = 0; i < numRows; i++) {
322 sortedData[i] = new DataRow(data[i]);
323 }
324 Arrays.sort(sortedData);
325
326
327 final double weight = 1d / numComponents;
328
329
330 final List<Pair<Double, MultivariateNormalDistribution>> components =
331 new ArrayList<>(numComponents);
332
333
334 for (int binIndex = 0; binIndex < numComponents; binIndex++) {
335
336 final int minIndex = (binIndex * numRows) / numComponents;
337
338
339 final int maxIndex = ((binIndex + 1) * numRows) / numComponents;
340
341
342 final int numBinRows = maxIndex - minIndex;
343
344
345 final double[][] binData = new double[numBinRows][numCols];
346
347
348 final double[] columnMeans = new double[numCols];
349
350
351 for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) {
352 for (int j = 0; j < numCols; j++) {
353 final double val = sortedData[i].getRow()[j];
354 columnMeans[j] += val;
355 binData[iBin][j] = val;
356 }
357 }
358
359 MathArrays.scaleInPlace(1d / numBinRows, columnMeans);
360
361
362 final double[][] covMat
363 = new Covariance(binData).getCovarianceMatrix().getData();
364 final MultivariateNormalDistribution mvn
365 = new MultivariateNormalDistribution(columnMeans, covMat);
366
367 components.add(new Pair<>(weight, mvn));
368 }
369
370 return new MixtureMultivariateNormalDistribution(components);
371 }
372
373
374
375
376
377
378 public double getLogLikelihood() {
379 return logLikelihood;
380 }
381
382
383
384
385
386
387 public MixtureMultivariateNormalDistribution getFittedModel() {
388 return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
389 }
390
391
392
393
394 private static final class DataRow implements Comparable<DataRow> {
395
396 private final double[] row;
397
398 private Double mean;
399
400
401
402
403
404 DataRow(final double[] data) {
405
406 row = data;
407
408 mean = 0d;
409 for (int i = 0; i < data.length; i++) {
410 mean += data[i];
411 }
412 mean /= data.length;
413 }
414
415
416
417
418
419
420 @Override
421 public int compareTo(final DataRow other) {
422 return mean.compareTo(other.mean);
423 }
424
425
426 @Override
427 public boolean equals(Object other) {
428
429 if (this == other) {
430 return true;
431 }
432
433 if (other instanceof DataRow) {
434 return MathArrays.equals(row, ((DataRow) other).row);
435 }
436
437 return false;
438 }
439
440
441 @Override
442 public int hashCode() {
443 return Arrays.hashCode(row);
444 }
445
446
447
448
449 public double[] getRow() {
450 return row;
451 }
452 }
453 }
454