/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math4.legacy.distribution.fitting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math4.legacy.distribution.MultivariateNormalDistribution;
import org.apache.commons.math4.legacy.exception.ConvergenceException;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NumberIsTooLargeException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
import org.apache.commons.math4.legacy.linear.RealMatrix;
import org.apache.commons.math4.legacy.linear.SingularMatrixException;
import org.apache.commons.math4.legacy.stat.correlation.Covariance;
import org.apache.commons.math4.core.jdkmath.JdkMath;
import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.core.Pair;

/**
 * Expectation-Maximization algorithm for fitting the parameters of
 * multivariate normal mixture model distributions.
 *
 * This implementation is pure original code based on <a
 * href="https://www.ee.washington.edu/techsite/papers/documents/UWEETR-2010-0002.pdf">
 * EM Demystified: An Expectation-Maximization Tutorial</a> by Yihua Chen and Maya R. Gupta,
 * Department of Electrical Engineering, University of Washington, Seattle, WA 98195.
 * It was verified using external tools like <a
 * href="http://cran.r-project.org/web/packages/mixtools/index.html">CRAN Mixtools</a>
 * (see the JUnit test cases) but it is <strong>not</strong> based on Mixtools code at all.
 * The discussion of the origin of this class can be seen in the comments of the <a
 * href="https://issues.apache.org/jira/browse/MATH-817">MATH-817</a> JIRA issue.
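 *
 * <p>A minimal usage sketch. The method name {@code fitTwoComponents} and the
 * {@code data} argument (one observation per row, every row with the same
 * number of columns) are illustrative only; {@link #estimate(double[][],int)
 * estimate} supplies a starting mixture, {@link #fit(MixtureMultivariateNormalDistribution)
 * fit} runs the EM iterations, and the result is read back from
 * {@link #getFittedModel()}:</p>
 *
 * <pre>{@code
 * static MixtureMultivariateNormalDistribution fitTwoComponents(double[][] data) {
 *     MultivariateNormalMixtureExpectationMaximization fitter =
 *         new MultivariateNormalMixtureExpectationMaximization(data);
 *     MixtureMultivariateNormalDistribution initialMix =
 *         MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
 *     fitter.fit(initialMix);
 *     return fitter.getFittedModel();
 * }
 * }</pre>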
 * @since 3.2
 */
public class MultivariateNormalMixtureExpectationMaximization {
    /**
     * Default maximum number of iterations allowed per fitting process.
     */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;
    /**
     * Default convergence threshold for fitting.
     */
    private static final double DEFAULT_THRESHOLD = 1E-5;
    /**
     * The data to fit.
     */
    private final double[][] data;
    /**
     * The model fit against the data.
     */
    private MixtureMultivariateNormalDistribution fittedModel;
    /**
     * The log likelihood of the data given the fitted model.
     */
    private double logLikelihood;

    /**
     * Creates an object to fit a multivariate normal mixture model to data.
     *
     * @param data Data to use in fitting procedure
     * @throws NotStrictlyPositiveException if data has no rows
     * @throws DimensionMismatchException if rows of data have different numbers
     *             of columns
     * @throws NumberIsTooSmallException if the number of columns in the data is
     *             less than 1
     */
    public MultivariateNormalMixtureExpectationMaximization(double[][] data)
        throws NotStrictlyPositiveException,
               DimensionMismatchException,
               NumberIsTooSmallException {
        if (data.length < 1) {
            throw new NotStrictlyPositiveException(data.length);
        }

        this.data = new double[data.length][data[0].length];

        for (int i = 0; i < data.length; i++) {
            if (data[i].length != data[0].length) {
                // Jagged arrays not allowed
                throw new DimensionMismatchException(data[i].length,
                                                     data[0].length);
            }
            if (data[i].length < 1) {
                throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL,
                                                    data[i].length, 1, true);
            }
            this.data[i] = Arrays.copyOf(data[i], data[i].length);
        }
    }

    /**
     * Fit a mixture model to the data supplied to the constructor.
     *
     * The quality of the fit depends on the data provided to the constructor
     * and on the initial mixture provided to this function. If the likelihood
     * surface has many local optima, multiple runs of the fitting function with
     * different initial mixtures may be required to find the best solution.
     * If a {@link SingularMatrixException} is encountered, it is possible that
     * another initialization would work.
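     *
     * <p>A minimal sketch of that retry idea (the names {@code fitter},
     * {@code firstStart} and {@code otherStart} are illustrative; the
     * alternative starting mixture is assumed to come from the caller, e.g.
     * from {@link #estimate(double[][],int)} applied to a perturbed data set):</p>
     *
     * <pre>{@code
     * try {
     *     fitter.fit(firstStart, 500, 1e-6);
     * } catch (SingularMatrixException e) {
     *     // A component's covariance matrix became singular; retry from another start.
     *     fitter.fit(otherStart, 500, 1e-6);
     * }
     * }</pre>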
     *
     * @param initialMixture Model containing initial values of weights and
     *            multivariate normals
     * @param maxIterations Maximum number of iterations allowed for the fit
     * @param threshold Convergence threshold on the difference in log
     *             likelihood between successive iterations
     * @throws SingularMatrixException if any component's covariance matrix is
     *             singular during fitting
     * @throws NotStrictlyPositiveException if {@code maxIterations} is less than one
     *             or {@code threshold} is less than {@code Double.MIN_VALUE}
     * @throws DimensionMismatchException if the length of the {@code initialMixture}
     *             mean vectors and the number of data columns are not equal
     */
    public void fit(final MixtureMultivariateNormalDistribution initialMixture,
                    final int maxIterations,
                    final double threshold)
            throws SingularMatrixException,
                   NotStrictlyPositiveException,
                   DimensionMismatchException {
        if (maxIterations < 1) {
            throw new NotStrictlyPositiveException(maxIterations);
        }

        if (threshold < Double.MIN_VALUE) {
            throw new NotStrictlyPositiveException(threshold);
        }

        final int n = data.length;

        // Number of data columns. Jagged data already rejected in constructor,
        // so we can assume the lengths of each row are equal.
        final int numCols = data[0].length;
        final int k = initialMixture.getComponents().size();

        final int numMeanColumns
            = initialMixture.getComponents().get(0).getSecond().getMeans().length;

        if (numMeanColumns != numCols) {
            throw new DimensionMismatchException(numMeanColumns, numCols);
        }

        int numIterations = 0;
        double previousLogLikelihood = 0d;

        logLikelihood = Double.NEGATIVE_INFINITY;

        // Initialize model to fit to initial mixture.
        fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());

        while (numIterations++ <= maxIterations &&
               JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
            previousLogLikelihood = logLikelihood;
            double sumLogLikelihood = 0d;

            // Mixture components
            final List<Pair<Double, MultivariateNormalDistribution>> components
                = fittedModel.getComponents();

            // Weight and distribution of each component
            final double[] weights = new double[k];

            final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];

            for (int j = 0; j < k; j++) {
                weights[j] = components.get(j).getFirst();
                mvns[j] = components.get(j).getSecond();
            }

            // E-step: compute the data dependent parameters of the expectation
            // function.
            // For each data row, the fraction of that row's total density
            // contributed by each component (the "responsibility").
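            // Concretely (restating the computation below), for row x_i and
            // component j:
            //   gamma[i][j] = weight_j * N(x_i; mean_j, cov_j) / density(x_i)
            // where density(x_i) = sum_l weight_l * N(x_i; mean_l, cov_l),
            // so each row's responsibilities sum to one.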
            final double[][] gamma = new double[n][k];

            // Sum of gamma for each component
            final double[] gammaSums = new double[k];

            // Sum of gamma times its data row, for each component
            final double[][] gammaDataProdSums = new double[k][numCols];

            for (int i = 0; i < n; i++) {
                final double rowDensity = fittedModel.density(data[i]);
                sumLogLikelihood += JdkMath.log(rowDensity);

                for (int j = 0; j < k; j++) {
                    gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
                    gammaSums[j] += gamma[i][j];

                    for (int col = 0; col < numCols; col++) {
                        gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
                    }
                }
            }

            logLikelihood = sumLogLikelihood / n;

            // M-step: compute the new parameters based on the expectation
            // function.
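            // Restating the updates implemented below, for component j:
            //   newWeight_j = (sum_i gamma[i][j]) / n
            //   newMean_j   = (sum_i gamma[i][j] * x_i) / (sum_i gamma[i][j])
            //   newCov_j    = (sum_i gamma[i][j] * (x_i - newMean_j)(x_i - newMean_j)^T)
            //                   / (sum_i gamma[i][j])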
            final double[] newWeights = new double[k];
            final double[][] newMeans = new double[k][numCols];

            for (int j = 0; j < k; j++) {
                newWeights[j] = gammaSums[j] / n;
                for (int col = 0; col < numCols; col++) {
                    newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
                }
            }

            // Compute new covariance matrices
            final RealMatrix[] newCovMats = new RealMatrix[k];
            for (int j = 0; j < k; j++) {
                newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
            }
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < k; j++) {
                    final RealMatrix vec
                        = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
                    final RealMatrix dataCov
                        = vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
                    newCovMats[j] = newCovMats[j].add(dataCov);
                }
            }

            // Converting to arrays for use by fitted model
            final double[][][] newCovMatArrays = new double[k][numCols][numCols];
            for (int j = 0; j < k; j++) {
                newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
                newCovMatArrays[j] = newCovMats[j].getData();
            }

            // Update current model
            fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
                                                                    newMeans,
                                                                    newCovMatArrays);
        }

        if (JdkMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
            // Did not converge before the maximum number of iterations
            throw new ConvergenceException();
        }
    }

    /**
     * Fit a mixture model to the data supplied to the constructor.
     *
     * The quality of the fit depends on the data provided to the constructor
     * and on the initial mixture provided to this function. If the likelihood
     * surface has many local optima, multiple runs of the fitting function with
     * different initial mixtures may be required to find the best solution.
     * If a {@link SingularMatrixException} is encountered, it is possible that
     * another initialization would work.
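     *
     * <p>This overload is equivalent to calling
     * {@link #fit(MixtureMultivariateNormalDistribution, int, double)
     * fit(initialMixture, 1000, 1e-5)}, i.e. it uses the default maximum number
     * of iterations and the default convergence threshold.</p>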
     *
     * @param initialMixture Model containing initial values of weights and
     *            multivariate normals
     * @throws SingularMatrixException if any component's covariance matrix is
     *             singular during fitting
     * @throws NotStrictlyPositiveException if the maximum number of iterations
     *             used for the fit is less than one or the threshold is less
     *             than {@code Double.MIN_VALUE}
     */
    public void fit(MixtureMultivariateNormalDistribution initialMixture)
        throws SingularMatrixException,
               NotStrictlyPositiveException {
        fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    /**
     * Helper method to create a multivariate normal mixture model which can be
     * used to initialize {@link #fit(MixtureMultivariateNormalDistribution)}.
     *
     * This method uses the supplied data to try to determine a good mixture
     * model at which to start the fit, but it is not guaranteed to supply a
     * model which will find the optimal solution or even converge.
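     *
     * <p>The starting model is built by sorting the rows by their mean and
     * splitting them into {@code numComponents} bins of (nearly) equal size; for
     * example, with 100 rows and {@code numComponents = 3} the bins cover sorted
     * rows [0, 33), [33, 66) and [66, 100). Each bin contributes one component,
     * with uniform weight, whose mean vector and covariance matrix are estimated
     * from the rows in that bin.</p>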
     *
     * @param data Data to estimate distribution
     * @param numComponents Number of components for estimated mixture
     * @return Multivariate normal mixture model estimated from the data
     * @throws NumberIsTooLargeException if {@code numComponents} is greater
     *             than the number of data rows.
     * @throws NumberIsTooSmallException if {@code numComponents < 1}.
     * @throws NotStrictlyPositiveException if data has fewer than 2 rows
     * @throws DimensionMismatchException if rows of data have different numbers
     *             of columns
     */
    public static MixtureMultivariateNormalDistribution estimate(final double[][] data,
                                                                 final int numComponents)
        throws NotStrictlyPositiveException,
               DimensionMismatchException {
        if (data.length < 2) {
            throw new NotStrictlyPositiveException(data.length);
        }
        if (numComponents < 1) {
            throw new NumberIsTooSmallException(numComponents, 1, true);
        }
        if (numComponents > data.length) {
            throw new NumberIsTooLargeException(numComponents, data.length, true);
        }

        final int numRows = data.length;
        final int numCols = data[0].length;

        // sort the data
        final DataRow[] sortedData = new DataRow[numRows];
        for (int i = 0; i < numRows; i++) {
            sortedData[i] = new DataRow(data[i]);
        }
        Arrays.sort(sortedData);

        // uniform weight for each bin
        final double weight = 1d / numComponents;

        // components of mixture model to be created
        final List<Pair<Double, MultivariateNormalDistribution>> components =
                new ArrayList<>(numComponents);

        // create a component based on data in each bin
        for (int binIndex = 0; binIndex < numComponents; binIndex++) {
            // minimum index (inclusive) from sorted data for this bin
            final int minIndex = (binIndex * numRows) / numComponents;

            // maximum index (exclusive) from sorted data for this bin
            final int maxIndex = ((binIndex + 1) * numRows) / numComponents;

            // number of data records that will be in this bin
            final int numBinRows = maxIndex - minIndex;

            // data for this bin
            final double[][] binData = new double[numBinRows][numCols];

            // mean of each column for the data in this bin
            final double[] columnMeans = new double[numCols];

            // populate bin and create component
            for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) {
                for (int j = 0; j < numCols; j++) {
                    final double val = sortedData[i].getRow()[j];
                    columnMeans[j] += val;
                    binData[iBin][j] = val;
                }
            }

            MathArrays.scaleInPlace(1d / numBinRows, columnMeans);

            // covariance matrix for this bin
            final double[][] covMat
                = new Covariance(binData).getCovarianceMatrix().getData();
            final MultivariateNormalDistribution mvn
                = new MultivariateNormalDistribution(columnMeans, covMat);

            components.add(new Pair<>(weight, mvn));
        }

        return new MixtureMultivariateNormalDistribution(components);
    }

    /**
     * Gets the log likelihood of the data under the fitted model
     * (the mean log density per data row).
     *
     * @return Log likelihood of the data, or zero if no fit has been performed
     */
    public double getLogLikelihood() {
        return logLikelihood;
    }

    /**
     * Gets the fitted model.
     *
     * @return the fitted model (a new instance built from the current
     *             components); a successful fit must have been performed first
     */
    public MixtureMultivariateNormalDistribution getFittedModel() {
        return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
    }

    /**
     * Class used for sorting user-supplied data.
     */
    private static final class DataRow implements Comparable<DataRow> {
        /** One data row. */
        private final double[] row;
        /** Mean of the data row. */
        private Double mean;

        /**
         * Create a data row.
         * @param data Data to use for the row
         */
        DataRow(final double[] data) {
            // Store reference.
            row = data;
            // Compute mean.
            mean = 0d;
            for (int i = 0; i < data.length; i++) {
                mean += data[i];
            }
            mean /= data.length;
        }

        /**
         * Compare two data rows.
         * @param other The other row
         * @return int for sorting
         */
        @Override
        public int compareTo(final DataRow other) {
            return mean.compareTo(other.mean);
        }

        /** {@inheritDoc} */
        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }

            if (other instanceof DataRow) {
                return MathArrays.equals(row, ((DataRow) other).row);
            }

            return false;
        }

        /** {@inheritDoc} */
        @Override
        public int hashCode() {
            return Arrays.hashCode(row);
        }

        /**
         * Get a data row.
         * @return data row array
         */
        public double[] getRow() {
            return row;
        }
    }
}