View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.ml.clustering;
19  
20  import org.apache.commons.math4.legacy.exception.NullArgumentException;
21  import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
22  import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
23  import org.apache.commons.math4.legacy.core.Pair;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.sampling.ListSampler;
26  
27  import java.util.ArrayList;
28  import java.util.Collection;
29  import java.util.List;
30  
31  /**
32   * Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf">
33   * based on KMeans</a>.
34   *
35   * @param <T> Type of the points to cluster.
36   */
37  public class MiniBatchKMeansClusterer<T extends Clusterable>
38      extends KMeansPlusPlusClusterer<T> {
39      /** Batch data size in iteration. */
40      private final int batchSize;
41      /** Iteration count of initialize the centers. */
42      private final int initIterations;
43      /** Data size of batch to initialize the centers. */
44      private final int initBatchSize;
45      /** Maximum number of iterations during which no improvement is occuring. */
46      private final int maxNoImprovementTimes;
47  
48  
49      /**
50       * Build a clusterer.
51       *
52       * @param k Number of clusters to split the data into.
53       * @param maxIterations Maximum number of iterations to run the algorithm for all the points,
54       * The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize},
55       * where {@code size} is the number of points to cluster.
56       * Disabled if negative.
57       * @param batchSize Batch size for training iterations.
58       * @param initIterations Number of iterations allowed in order to find out the best initial centers.
59       * @param initBatchSize Batch size for initializing the clusters centers.
60       * A value of {@code 3 * batchSize} should be suitable in most cases.
61       * @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring.
62       * A value of 10 is suitable in most cases.
63       * @param measure Distance measure.
64       * @param random Random generator.
65       * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
66       */
67      public MiniBatchKMeansClusterer(final int k,
68                                      final int maxIterations,
69                                      final int batchSize,
70                                      final int initIterations,
71                                      final int initBatchSize,
72                                      final int maxNoImprovementTimes,
73                                      final DistanceMeasure measure,
74                                      final UniformRandomProvider random,
75                                      final EmptyClusterStrategy emptyStrategy) {
76          super(k, maxIterations, measure, random, emptyStrategy);
77  
78          if (batchSize < 1) {
79              throw new NumberIsTooSmallException(batchSize, 1, true);
80          }
81          if (initIterations < 1) {
82              throw new NumberIsTooSmallException(initIterations, 1, true);
83          }
84          if (initBatchSize < 1) {
85              throw new NumberIsTooSmallException(initBatchSize, 1, true);
86          }
87          if (maxNoImprovementTimes < 1) {
88              throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
89          }
90  
91          this.batchSize = batchSize;
92          this.initIterations = initIterations;
93          this.initBatchSize = initBatchSize;
94          this.maxNoImprovementTimes = maxNoImprovementTimes;
95      }
96  
97      /**
98       * Runs the MiniBatch K-means clustering algorithm.
99       *
100      * @param points Points to cluster (cannot be {@code null}).
101      * @return the clusters.
102      * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
103      * if the number of points is smaller than the number of clusters.
104      */
105     @Override
106     public List<CentroidCluster<T>> cluster(final Collection<T> points) {
107         // Sanity check.
108         NullArgumentException.check(points);
109         if (points.size() < getNumberOfClusters()) {
110             throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
111         }
112 
113         final int pointSize = points.size();
114         final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
115         final int max = getMaxIterations() < 0 ?
116             Integer.MAX_VALUE :
117             getMaxIterations() * batchCount;
118 
119         final List<T> pointList = new ArrayList<>(points);
120         List<CentroidCluster<T>> clusters = initialCenters(pointList);
121 
122         final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize,
123                                                                         maxNoImprovementTimes);
124         for (int i = 0; i < max; i++) {
125             clearClustersPoints(clusters);
126             final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
127             // Training step.
128             final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
129             final double squareDistance = pair.getFirst();
130             clusters = pair.getSecond();
131             // Check whether the training can finished early.
132             if (evaluator.converge(squareDistance, pointSize)) {
133                 break;
134             }
135         }
136 
137         // Add every mini batch points to their nearest cluster.
138         clearClustersPoints(clusters);
139         for (final T point : points) {
140             addToNearestCentroidCluster(point, clusters);
141         }
142 
143         return clusters;
144     }
145 
146     /**
147      * Helper method.
148      *
149      * @param clusters Clusters to clear.
150      */
151     private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
152         for (CentroidCluster<T> cluster : clusters) {
153             cluster.getPoints().clear();
154         }
155     }
156 
157     /**
158      * Mini batch iteration step.
159      *
160      * @param batchPoints Points selected for this batch.
161      * @param clusters Centers of the clusters.
162      * @return the squared distance of all the batch points to the nearest center.
163      */
164     private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
165                                                         final List<CentroidCluster<T>> clusters) {
166         // Add every mini batch points to their nearest cluster.
167         for (final T point : batchPoints) {
168             addToNearestCentroidCluster(point, clusters);
169         }
170         final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
171         // Add every mini batch points to their nearest cluster again.
172         double squareDistance = 0.0;
173         for (T point : batchPoints) {
174             final double d = addToNearestCentroidCluster(point, newClusters);
175             squareDistance += d * d;
176         }
177 
178         return new Pair<>(squareDistance, newClusters);
179     }
180 
181     /**
182      * Initializes the clusters centers.
183      *
184      * @param points Points used to initialize the centers.
185      * @return clusters with their center initialized.
186      */
187     private List<CentroidCluster<T>> initialCenters(final List<T> points) {
188         final List<T> validPoints = initBatchSize < points.size() ?
189             ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
190             new ArrayList<>(points);
191         double nearestSquareDistance = Double.POSITIVE_INFINITY;
192         List<CentroidCluster<T>> bestCenters = null;
193 
194         for (int i = 0; i < initIterations; i++) {
195             final List<T> initialPoints = (initBatchSize < points.size()) ?
196                 ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
197                 new ArrayList<>(points);
198             final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
199             final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
200             final double squareDistance = pair.getFirst();
201             final List<CentroidCluster<T>> newClusters = pair.getSecond();
202             //Find out a best centers that has the nearest total square distance.
203             if (squareDistance < nearestSquareDistance) {
204                 nearestSquareDistance = squareDistance;
205                 bestCenters = newClusters;
206             }
207         }
208         return bestCenters;
209     }
210 
211     /**
212      * Adds a point to the cluster whose center is closest.
213      *
214      * @param point Point to add.
215      * @param clusters Clusters.
216      * @return the distance between point and the closest center.
217      */
218     private double addToNearestCentroidCluster(final T point,
219                                                final List<CentroidCluster<T>> clusters) {
220         double minDistance = Double.POSITIVE_INFINITY;
221         CentroidCluster<T> closestCentroidCluster = null;
222 
223         // Find cluster closest to the point.
224         for (CentroidCluster<T> centroidCluster : clusters) {
225             final double distance = distance(point, centroidCluster.getCenter());
226             if (distance < minDistance) {
227                 minDistance = distance;
228                 closestCentroidCluster = centroidCluster;
229             }
230         }
231         NullArgumentException.check(closestCentroidCluster);
232         closestCentroidCluster.addPoint(point);
233 
234         return minDistance;
235     }
236 
237     /**
238      * Stopping criterion.
239      * The evaluator checks whether improvement occurred during the
240      * {@link #maxNoImprovementTimes allowed number of successive iterations}.
241      */
242     private static final class ImprovementEvaluator {
243         /** Batch size. */
244         private final int batchSize;
245         /** Maximum number of iterations during which no improvement is occuring. */
246         private final int maxNoImprovementTimes;
247         /**
248          * <a href="https://en.wikipedia.org/wiki/Moving_average">
249          * Exponentially Weighted Average</a> of the squared
250          * diff to monitor the convergence while discarding
251          * minibatch-local stochastic variability.
252          */
253         private double ewaInertia = Double.NaN;
254         /** Minimum value of {@link #ewaInertia} during iteration. */
255         private double ewaInertiaMin = Double.POSITIVE_INFINITY;
256         /** Number of iteration during which {@link #ewaInertia} did not improve. */
257         private int noImprovementTimes;
258 
259         /**
260          * @param batchSize Number of elements for each batch iteration.
261          * @param maxNoImprovementTimes Maximum number of iterations during
262          * which no improvement is occuring.
263          */
264         private ImprovementEvaluator(int batchSize,
265                                      int maxNoImprovementTimes) {
266             this.batchSize = batchSize;
267             this.maxNoImprovementTimes = maxNoImprovementTimes;
268         }
269 
270         /**
271          * Stopping criterion.
272          *
273          * @param squareDistance Total square distance from the batch points
274          * to their nearest center.
275          * @param pointSize Number of data points.
276          * @return {@code true} if no improvement was made after the allowed
277          * number of iterations, {@code false} otherwise.
278          */
279         public boolean converge(final double squareDistance,
280                                 final int pointSize) {
281             final double batchInertia = squareDistance / batchSize;
282             if (Double.isNaN(ewaInertia)) {
283                 ewaInertia = batchInertia;
284             } else {
285                 final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
286                 ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
287             }
288 
289             if (ewaInertia < ewaInertiaMin) {
290                 // Improved.
291                 noImprovementTimes = 0;
292                 ewaInertiaMin = ewaInertia;
293             } else {
294                 // No improvement.
295                 ++noImprovementTimes;
296             }
297 
298             return noImprovementTimes >= maxNoImprovementTimes;
299         }
300     }
301 }