CalinskiHarabasz.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math4.legacy.ml.clustering.evaluation;

import org.apache.commons.math4.legacy.exception.InsufficientDataException;
import org.apache.commons.math4.legacy.ml.clustering.Cluster;
import org.apache.commons.math4.legacy.ml.clustering.ClusterEvaluator;
import org.apache.commons.math4.legacy.ml.clustering.Clusterable;
import org.apache.commons.math4.legacy.core.MathArrays;

import java.util.Collection;
import java.util.List;

/**
 * Compute the Calinski and Harabasz score.
 * <p>
 * It is also known as the Variance Ratio Criterion.
 * <p>
 * The score is defined as the ratio of the between-cluster dispersion to
 * the within-cluster dispersion, each scaled by its degrees of freedom:
 * <pre>
 *   score = (B * (n - k)) / (W * (k - 1))
 * </pre>
 * where {@code B} is the sum, over the clusters, of the cluster size times the
 * squared distance between the cluster centroid and the centroid of all points,
 * {@code W} is the sum of the squared distances between every point and the
 * centroid of its own cluster, {@code n} is the total number of points and
 * {@code k} is the number of non-empty clusters.
 * A higher score denotes a better clustering.
 *
 * @see <a href="https://www.tandfonline.com/doi/abs/10.1080/03610927408827101">A dendrite method for cluster
 * analysis</a>
 */
public class CalinskiHarabasz implements ClusterEvaluator {
    /**
     * {@inheritDoc}
     * <p>
     * Empty clusters are ignored: they contribute neither to the dispersions
     * nor to the cluster count.
     * If the within-cluster dispersion is exactly zero, {@code 1} is returned.
     * Otherwise, when fewer than two non-empty clusters are supplied, the
     * {@code k - 1} divisor is zero and the result is not a finite score
     * (it is {@code NaN} or infinite).
     *
     * @throws InsufficientDataException if the clusters do not contain any point.
     */
    @Override
    public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
        final int dimension = dimensionOfClusters(clusters);
        final double[] centroid = meanOfClusters(clusters, dimension);

        double intraDistanceProduct = 0.0;
        double extraDistanceProduct = 0.0;
        int pointCount = 0;
        int clusterCount = 0;
        for (Cluster<? extends Clusterable> cluster : clusters) {
            final Collection<? extends Clusterable> points = cluster.getPoints();
            if (points.isEmpty()) {
                // The centroid of an empty cluster is undefined (0 / 0); using it
                // would turn the whole score into NaN.
                continue;
            }
            clusterCount++;
            pointCount += points.size();

            // Calculate the center of the cluster.
            final double[] clusterCentroid = mean(points, dimension);
            for (Clusterable p : points) {
                // Increase the within-cluster dispersion.
                intraDistanceProduct += squaredDistance(clusterCentroid, p.getPoint());
            }
            // Increase the between-cluster dispersion, weighted by the cluster size.
            extraDistanceProduct += points.size() * squaredDistance(centroid, clusterCentroid);
        }

        // Return the ratio of the between-cluster dispersion to the within-cluster
        // dispersion, each scaled by its degrees of freedom.
        return intraDistanceProduct == 0.0 ? 1.0 :
                (extraDistanceProduct * (pointCount - clusterCount) /
                        (intraDistanceProduct * (clusterCount - 1)));
    }

    /**
     * {@inheritDoc}
     * <p>
     * A higher Calinski-Harabasz score is considered better.
     */
    @Override
    public boolean isBetterScore(double a,
                                 double b) {
        return a > b;
    }

    /**
     * Calculate the squared Euclidean distance between two points.
     * <pre>
     *   distance^2 = sum((p1[i]-p2[i])^2)
     * </pre>
     * The length check is delegated to {@code MathArrays.checkEqualLength}.
     *
     * @param p1 Coordinates of the first point.
     * @param p2 Coordinates of the second point.
     * @return the sum of the squared coordinate differences.
     */
    private static double squaredDistance(double[] p1, double[] p2) {
        MathArrays.checkEqualLength(p1, p2);
        double sum = 0;
        for (int i = 0; i < p1.length; i++) {
            final double dp = p1[i] - p2[i];
            sum += dp * dp;
        }
        return sum;
    }

    /**
     * Calculate the mean of all the points.
     * <p>
     * The collection must not be empty, otherwise every coordinate of the
     * result is {@code NaN} (division of zero by zero).
     *
     * @param points    A non-empty collection of points.
     * @param dimension The dimension of each point.
     * @return The mean value.
     */
    private static double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
        final double[] centroid = new double[dimension];
        for (final Clusterable p : points) {
            final double[] point = p.getPoint();
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += point[i];
            }
        }
        final int size = points.size();
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= size;
        }
        return centroid;
    }

    /**
     * Calculate the mean of all the points in the clusters.
     *
     * @param clusters  A collection of clusters.
     * @param dimension The dimension of each point.
     * @return The mean value.
     */
    private static double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters,
                                           final int dimension) {
        final double[] centroid = new double[dimension];
        int allPointsCount = 0;
        for (Cluster<? extends Clusterable> cluster : clusters) {
            for (Clusterable p : cluster.getPoints()) {
                final double[] point = p.getPoint();
                for (int i = 0; i < centroid.length; i++) {
                    centroid[i] += point[i];
                }
                allPointsCount++;
            }
        }
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= allPointsCount;
        }
        return centroid;
    }

    /**
     * Detect the dimension of points in the clusters.
     *
     * @param clusters collection of cluster
     * @return The dimension of the first point in clusters
     * @throws InsufficientDataException if the clusters do not contain any point.
     */
    private static int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
        // Iterate until the first point is found.
        for (Cluster<? extends Clusterable> cluster : clusters) {
            for (Clusterable p : cluster.getPoints()) {
                return p.getPoint().length;
            }
        }
        // Throw exception if there is no point.
        throw new InsufficientDataException();
    }
}