// MiniBatchKMeansClusterer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.legacy.ml.clustering;

import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
import org.apache.commons.math4.legacy.core.Pair;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ListSampler;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf">
 * based on KMeans</a>.
 *
 * <p>
 * The centers are first initialized by evaluating several candidate sets of
 * centers and keeping the best one. Each training iteration then samples a
 * mini batch of points, assigns them to their nearest center and adjusts the
 * centers. Training stops when the maximum number of iterations is reached or
 * when no improvement was observed during a given number of successive
 * iterations. Finally, all the points are assigned to their nearest center.
 * </p>
 *
 * @param <T> Type of the points to cluster.
 */
public class MiniBatchKMeansClusterer<T extends Clusterable>
    extends KMeansPlusPlusClusterer<T> {
    /** Number of points in each training batch. */
    private final int batchSize;
    /** Number of candidate sets of centers evaluated during initialization. */
    private final int initIterations;
    /** Number of points in the batches used to initialize the centers. */
    private final int initBatchSize;
    /** Maximum number of iterations during which no improvement is occurring. */
    private final int maxNoImprovementTimes;


    /**
     * Build a clusterer.
     *
     * @param k Number of clusters to split the data into.
     * @param maxIterations Maximum number of iterations to run the algorithm for all the points.
     * The actual number of training iterations will be at most
     * {@code maxIterations * ceil(size / batchSize)}, where {@code size} is the number
     * of points to cluster (the product is capped at {@code Integer.MAX_VALUE}).
     * If negative, the cap is {@code Integer.MAX_VALUE} iterations.
     * @param batchSize Batch size for training iterations.
     * @param initIterations Number of iterations allowed in order to find out the best initial centers.
     * @param initBatchSize Batch size for initializing the clusters centers.
     * A value of {@code 3 * batchSize} should be suitable in most cases.
     * @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occurring.
     * A value of 10 is suitable in most cases.
     * @param measure Distance measure.
     * @param random Random generator.
     * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
     * @throws NumberIsTooSmallException if {@code batchSize}, {@code initIterations},
     * {@code initBatchSize} or {@code maxNoImprovementTimes} is smaller than 1.
     */
    public MiniBatchKMeansClusterer(final int k,
                                    final int maxIterations,
                                    final int batchSize,
                                    final int initIterations,
                                    final int initBatchSize,
                                    final int maxNoImprovementTimes,
                                    final DistanceMeasure measure,
                                    final UniformRandomProvider random,
                                    final EmptyClusterStrategy emptyStrategy) {
        super(k, maxIterations, measure, random, emptyStrategy);

        if (batchSize < 1) {
            throw new NumberIsTooSmallException(batchSize, 1, true);
        }
        if (initIterations < 1) {
            throw new NumberIsTooSmallException(initIterations, 1, true);
        }
        if (initBatchSize < 1) {
            throw new NumberIsTooSmallException(initBatchSize, 1, true);
        }
        if (maxNoImprovementTimes < 1) {
            throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
        }

        this.batchSize = batchSize;
        this.initIterations = initIterations;
        this.initBatchSize = initBatchSize;
        this.maxNoImprovementTimes = maxNoImprovementTimes;
    }

    /**
     * Runs the MiniBatch K-means clustering algorithm.
     *
     * <p>
     * If the configured batch size exceeds the number of points, every
     * batch contains all the points (in random order).
     * </p>
     *
     * @param points Points to cluster (cannot be {@code null}).
     * @return the clusters.
     * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
     * if the number of points is smaller than the number of clusters.
     */
    @Override
    public List<CentroidCluster<T>> cluster(final Collection<T> points) {
        // Sanity check.
        NullArgumentException.check(points);
        final int pointSize = points.size();
        if (pointSize < getNumberOfClusters()) {
            throw new NumberIsTooSmallException(pointSize, getNumberOfClusters(), false);
        }

        // The sampler draws without repetition and rejects a sample size
        // larger than the population: never ask for more points than exist.
        final int effectiveBatchSize = Math.min(batchSize, pointSize);
        final int max = maxBatchIterations(pointSize);

        final List<T> pointList = new ArrayList<>(points);
        List<CentroidCluster<T>> clusters = initialCenters(pointList);

        final ImprovementEvaluator evaluator = new ImprovementEvaluator(effectiveBatchSize,
                                                                        maxNoImprovementTimes);
        for (int i = 0; i < max; i++) {
            // Discard the assignments of the previous iteration.
            clearClustersPoints(clusters);
            // Randomly sample a mini batch of points.
            final List<T> batchPoints = ListSampler.sample(getRandomGenerator(),
                                                           pointList,
                                                           effectiveBatchSize);
            // Training step.
            final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
            final double squareDistance = pair.getFirst();
            clusters = pair.getSecond();
            // Check whether the training can finish early.
            if (evaluator.converge(squareDistance, pointSize)) {
                break;
            }
        }

        // Assign every point of the data set to its nearest cluster.
        clearClustersPoints(clusters);
        for (final T point : points) {
            addToNearestCentroidCluster(point, clusters);
        }

        return clusters;
    }

    /**
     * Computes the maximum number of mini batch iterations.
     *
     * @param pointSize Number of points to cluster.
     * @return {@code Integer.MAX_VALUE} if the maximum number of iterations
     * is negative, otherwise the maximum number of iterations multiplied by
     * the number of batches needed to cover all the points, capped at
     * {@code Integer.MAX_VALUE}.
     */
    private int maxBatchIterations(final int pointSize) {
        final int maxIterations = getMaxIterations();
        if (maxIterations < 0) {
            return Integer.MAX_VALUE;
        }
        // "batchSize >= 1" is enforced by the constructor.
        final long batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
        // Multiply as "long": an "int" product could overflow to a negative
        // value, which would silently disable the training loop.
        return (int) Math.min((long) Integer.MAX_VALUE, maxIterations * batchCount);
    }

    /**
     * Removes all the points from the given clusters (the centers are untouched).
     *
     * @param clusters Clusters to clear.
     */
    private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
        for (CentroidCluster<T> cluster : clusters) {
            cluster.getPoints().clear();
        }
    }

    /**
     * Mini batch iteration step.
     * The batch points are assigned to their nearest center, the centers are
     * adjusted accordingly, and the batch points are then assigned to the
     * nearest adjusted center.
     *
     * @param batchPoints Points selected for this batch.
     * @param clusters Centers of the clusters.
     * @return a pair whose first element is the sum of the squared distances
     * between each batch point and its nearest adjusted center, and whose
     * second element is the list of adjusted clusters (containing the batch
     * points).
     */
    private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
                                                        final List<CentroidCluster<T>> clusters) {
        // Add every mini batch point to its nearest cluster.
        for (final T point : batchPoints) {
            addToNearestCentroidCluster(point, clusters);
        }
        final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
        // Add every mini batch point to its nearest adjusted cluster.
        double squareDistance = 0.0;
        for (final T point : batchPoints) {
            final double d = addToNearestCentroidCluster(point, newClusters);
            squareDistance += d * d;
        }

        return new Pair<>(squareDistance, newClusters);
    }

    /**
     * Initializes the clusters centers.
     * A validation sample is drawn once; then, {@link #initIterations} times,
     * a fresh sample is used to choose candidate centers, which are refined
     * and scored by one {@link #step(List, List) step} on the validation
     * sample. The candidate with the smallest score is retained.
     *
     * @param points Points used to initialize the centers.
     * @return clusters with their center initialized (never {@code null}).
     */
    private List<CentroidCluster<T>> initialCenters(final List<T> points) {
        final List<T> validPoints = sampleInitBatch(points);
        double nearestSquareDistance = Double.POSITIVE_INFINITY;
        List<CentroidCluster<T>> bestCenters = null;

        for (int i = 0; i < initIterations; i++) {
            final List<T> initialPoints = sampleInitBatch(points);
            final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
            final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
            final double squareDistance = pair.getFirst();
            // Retain the centers that have the smallest total square distance.
            // The first candidate is always retained so that the result is
            // defined even if no score is strictly below "infinity".
            if (bestCenters == null ||
                squareDistance < nearestSquareDistance) {
                nearestSquareDistance = squareDistance;
                bestCenters = pair.getSecond();
            }
        }
        return bestCenters;
    }

    /**
     * Draws the points of an initialization batch.
     *
     * @param points Points to sample from.
     * @return a random sample of {@link #initBatchSize} points, or a copy of
     * all the points if there are not more than {@link #initBatchSize} of them.
     */
    private List<T> sampleInitBatch(final List<T> points) {
        return initBatchSize < points.size() ?
            ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
            new ArrayList<>(points);
    }

    /**
     * Adds a point to the cluster whose center is closest.
     *
     * @param point Point to add.
     * @param clusters Clusters.
     * @return the distance between point and the closest center.
     * @throws NullArgumentException if no cluster is at a distance strictly
     * smaller than infinity (e.g. the list of clusters is empty).
     */
    private double addToNearestCentroidCluster(final T point,
                                               final List<CentroidCluster<T>> clusters) {
        double minDistance = Double.POSITIVE_INFINITY;
        CentroidCluster<T> closestCentroidCluster = null;

        // Find cluster closest to the point.
        for (CentroidCluster<T> centroidCluster : clusters) {
            final double distance = distance(point, centroidCluster.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                closestCentroidCluster = centroidCluster;
            }
        }
        NullArgumentException.check(closestCentroidCluster);
        closestCentroidCluster.addPoint(point);

        return minDistance;
    }

    /**
     * Stopping criterion.
     * The evaluator checks whether improvement occurred during the
     * {@link #maxNoImprovementTimes allowed number of successive iterations}.
     */
    private static final class ImprovementEvaluator {
        /** Batch size. */
        private final int batchSize;
        /** Maximum number of iterations during which no improvement is occurring. */
        private final int maxNoImprovementTimes;
        /**
         * <a href="https://en.wikipedia.org/wiki/Moving_average">
         * Exponentially Weighted Average</a> of the squared
         * diff to monitor the convergence while discarding
         * minibatch-local stochastic variability.
         */
        private double ewaInertia = Double.NaN;
        /** Minimum value of {@link #ewaInertia} during iteration. */
        private double ewaInertiaMin = Double.POSITIVE_INFINITY;
        /** Number of iterations during which {@link #ewaInertia} did not improve. */
        private int noImprovementTimes;

        /**
         * @param batchSize Number of elements for each batch iteration.
         * @param maxNoImprovementTimes Maximum number of iterations during
         * which no improvement is occurring.
         */
        private ImprovementEvaluator(int batchSize,
                                     int maxNoImprovementTimes) {
            this.batchSize = batchSize;
            this.maxNoImprovementTimes = maxNoImprovementTimes;
        }

        /**
         * Stopping criterion.
         * Updates the exponentially weighted average of the batch inertia
         * (mean squared distance per batch point) and counts the successive
         * calls during which that average did not reach a new minimum.
         *
         * @param squareDistance Total square distance from the batch points
         * to their nearest center.
         * @param pointSize Number of data points.
         * @return {@code true} if no improvement was made after the allowed
         * number of iterations, {@code false} otherwise.
         */
        public boolean converge(final double squareDistance,
                                final int pointSize) {
            final double batchInertia = squareDistance / batchSize;
            if (Double.isNaN(ewaInertia)) {
                // First call: nothing to average with yet.
                ewaInertia = batchInertia;
            } else {
                // Smoothing factor in (0, 1]. It must be computed in floating
                // point: an integer division would only ever yield 0 (average
                // frozen at its first value) or 1 (no smoothing at all).
                final double alpha = Math.min(2.0 * batchSize / (pointSize + 1.0), 1.0);
                ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
            }

            if (ewaInertia < ewaInertiaMin) {
                // Improved.
                noImprovementTimes = 0;
                ewaInertiaMin = ewaInertia;
            } else {
                // No improvement.
                ++noImprovementTimes;
            }

            return noImprovementTimes >= maxNoImprovementTimes;
        }
    }
}