1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
26
27
28
29
30 class ContinuousSamplerParametricTest {
31 private static Iterable<ContinuousSamplerTestData> getSamplerTestData() {
32 return ContinuousSamplersList.list();
33 }
34
35 @ParameterizedTest
36 @MethodSource("getSamplerTestData")
37 void testSampling(ContinuousSamplerTestData data) {
38 check(20000, data.getSampler(), data.getDeciles());
39 }
40
41
42
43
44
45
46
47
48
49
50
51 private static void check(long sampleSize,
52 ContinuousSampler sampler,
53 double[] deciles) {
54 final int numTests = 50;
55
56
57 final int numBins = 10;
58
59
60 int numFailures = 0;
61
62 final double expected = sampleSize / (double) numBins;
63
64 final long[] observed = new long[numBins];
65
66
67 final double chi2CriticalValue = 21.665994333461924;
68
69
70 final List<Double> failedStat = new ArrayList<>();
71 try {
72 final int lastDecileIndex = numBins - 1;
73 for (int i = 0; i < numTests; i++) {
74 Arrays.fill(observed, 0);
75 SAMPLE: for (long j = 0; j < sampleSize; j++) {
76 final double value = sampler.sample();
77
78 for (int k = 0; k < lastDecileIndex; k++) {
79 if (value < deciles[k]) {
80 ++observed[k];
81 continue SAMPLE;
82 }
83 }
84 ++observed[lastDecileIndex];
85 }
86
87
88 double chi2 = 0;
89 for (int k = 0; k < numBins; k++) {
90 final double diff = observed[k] - expected;
91 chi2 += diff * diff / expected;
92 }
93
94
95 if (chi2 > chi2CriticalValue) {
96 failedStat.add(chi2);
97 ++numFailures;
98 }
99 }
100 } catch (Exception e) {
101
102 throw new RuntimeException("Unexpected", e);
103 }
104
105
106
107
108
109
110
111
112
113 if (numFailures > 3) {
114 Assertions.fail(String.format(
115 "%s: Too many failures for sample size = %d " +
116 "(%d out of %d tests failed, chi2 > %.3f=%s)",
117 sampler, sampleSize, numFailures, numTests, chi2CriticalValue,
118 failedStat.stream().map(d -> String.format("%.3f", d))
119 .collect(Collectors.joining(", ", "[", "]"))));
120 }
121 }
122 }