1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math4.legacy.ml.clustering.evaluation;
19
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertFalse;
22 import static org.junit.Assert.assertTrue;
23
24 import java.util.ArrayList;
25 import java.util.List;
26
27 import org.apache.commons.math4.legacy.ml.clustering.Cluster;
28 import org.apache.commons.math4.legacy.ml.clustering.DoublePoint;
29 import org.apache.commons.math4.legacy.ml.clustering.ClusterEvaluator;
30 import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
31 import org.junit.Before;
32 import org.junit.Test;
33
34 public class SumOfClusterVariancesTest {
35
36 private ClusterEvaluator evaluator;
37
38 @Before
39 public void setUp() {
40 evaluator = new SumOfClusterVariances(new EuclideanDistance());
41 }
42
43 @Test
44 public void testScore() {
45 final DoublePoint[] points1 = new DoublePoint[] {
46 new DoublePoint(new double[] { 1 }),
47 new DoublePoint(new double[] { 2 }),
48 new DoublePoint(new double[] { 3 })
49 };
50
51 final DoublePoint[] points2 = new DoublePoint[] {
52 new DoublePoint(new double[] { 1 }),
53 new DoublePoint(new double[] { 5 }),
54 new DoublePoint(new double[] { 10 })
55 };
56
57 final List<Cluster<DoublePoint>> clusters = new ArrayList<>();
58
59 final Cluster<DoublePoint> cluster1 = new Cluster<>();
60 for (DoublePoint p : points1) {
61 cluster1.addPoint(p);
62 }
63 clusters.add(cluster1);
64
65 assertEquals(1.0/3.0, evaluator.score(clusters), 1e-6);
66
67 final Cluster<DoublePoint> cluster2 = new Cluster<>();
68 for (DoublePoint p : points2) {
69 cluster2.addPoint(p);
70 }
71 clusters.add(cluster2);
72
73 assertEquals(6.148148148, evaluator.score(clusters), 1e-6);
74 }
75
76 @Test
77 public void testOrdering() {
78 assertTrue(evaluator.isBetterScore(10, 20));
79 assertFalse(evaluator.isBetterScore(20, 1));
80 }
81 }