View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math4.legacy.ml.clustering;
19  
20  import java.util.Collection;
21  import java.util.List;
22  
23  import org.apache.commons.math4.legacy.ml.clustering.evaluation.SumOfClusterVariances;
24  
25  /**
26   * A wrapper around a k-means++ clustering algorithm which performs multiple trials
27   * and returns the best solution.
28   * @param <T> type of the points to cluster
29   * @since 3.2
30   */
31  public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
32  
33      /** The underlying k-means clusterer. */
34      private final KMeansPlusPlusClusterer<T> clusterer;
35  
36      /** The number of trial runs. */
37      private final int numTrials;
38  
39      /** The cluster evaluator to use. */
40      private final ClusterRanking evaluator;
41  
42      /** Build a clusterer.
43       * @param clusterer the k-means clusterer to use
44       * @param numTrials number of trial runs
45       */
46      public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
47                                          final int numTrials) {
48          this(clusterer,
49               numTrials,
50               ClusterEvaluator.ranking(new SumOfClusterVariances(clusterer.getDistanceMeasure())));
51      }
52  
53      /** Build a clusterer.
54       * @param clusterer the k-means clusterer to use
55       * @param numTrials number of trial runs
56       * @param evaluator the cluster evaluator to use
57       * @since 3.3
58       */
59      public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
60                                          final int numTrials,
61                                          final ClusterRanking evaluator) {
62          super(clusterer.getDistanceMeasure());
63          this.clusterer = clusterer;
64          this.numTrials = numTrials;
65          this.evaluator = evaluator;
66      }
67  
68      /**
69       * Runs the K-means++ clustering algorithm.
70       *
71       * @param points the points to cluster
72       * @return a list of clusters containing the points
73       * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException if
74       * the data points are null or the number of clusters is larger than the
75       * number of data points
76       * @throws org.apache.commons.math4.legacy.exception.ConvergenceException if
77       * an empty cluster is encountered and the underlying {@link KMeansPlusPlusClusterer}
78       * has its {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}.
79       */
80      @Override
81      public List<CentroidCluster<T>> cluster(final Collection<T> points) {
82          // at first, we have not found any clusters list yet
83          List<CentroidCluster<T>> best = null;
84          double bestRank = Double.NEGATIVE_INFINITY;
85  
86          // do several clustering trials
87          for (int i = 0; i < numTrials; ++i) {
88  
89              // compute a clusters list
90              List<CentroidCluster<T>> clusters = clusterer.cluster(points);
91  
92              // compute the rank of the current list
93              final double rank = evaluator.compute(clusters);
94  
95              if (rank > bestRank) {
96                  // this one is the best we have found so far, remember it
97                  best = clusters;
98                  bestRank = rank;
99              }
100         }
101 
102         // return the best clusters list found
103         return best;
104     }
105 }