1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.util.ArrayList;
20 import java.util.Arrays;
21 import java.util.List;
22 import java.util.stream.Collectors;
23 import org.junit.jupiter.api.Assertions;
24 import org.junit.jupiter.params.ParameterizedTest;
25 import org.junit.jupiter.params.provider.MethodSource;
26
27
28
29
30 class ContinuousSamplerParametricTest {
31 private static Iterable<ContinuousSamplerTestData> getSamplerTestData() {
32 return ContinuousSamplersList.list();
33 }
34
35 @ParameterizedTest
36 @MethodSource("getSamplerTestData")
37 void testSampling(ContinuousSamplerTestData data) {
38 check(20000, data.getSampler(), data.getDeciles());
39 }
40
41
42
43
44
45
46
47
48
49
50
51 private static void check(long sampleSize,
52 ContinuousSampler sampler,
53 double[] deciles) {
54 final int numTests = 50;
55
56
57 final int numBins = 10;
58
59
60 int numFailures = 0;
61
62 final double[] expected = new double[numBins];
63 Arrays.fill(expected, sampleSize / (double) numBins);
64
65 final long[] observed = new long[numBins];
66
67
68 final double chi2CriticalValue = 21.665994333461924;
69
70
71 final List<Double> failedStat = new ArrayList<>();
72 try {
73 final int lastDecileIndex = numBins - 1;
74 for (int i = 0; i < numTests; i++) {
75 Arrays.fill(observed, 0);
76 SAMPLE: for (long j = 0; j < sampleSize; j++) {
77 final double value = sampler.sample();
78
79 for (int k = 0; k < lastDecileIndex; k++) {
80 if (value < deciles[k]) {
81 ++observed[k];
82 continue SAMPLE;
83 }
84 }
85 ++observed[lastDecileIndex];
86 }
87
88
89 double chi2 = 0;
90 for (int k = 0; k < numBins; k++) {
91 final double diff = observed[k] - expected[k];
92 chi2 += diff * diff / expected[k];
93 }
94
95
96 if (chi2 > chi2CriticalValue) {
97 failedStat.add(chi2);
98 ++numFailures;
99 }
100 }
101 } catch (Exception e) {
102
103 throw new RuntimeException("Unexpected", e);
104 }
105
106
107
108
109
110
111
112
113
114 if (numFailures > 3) {
115 Assertions.fail(String.format(
116 "%s: Too many failures for sample size = %d " +
117 "(%d out of %d tests failed, chi2 > %.3f=%s)",
118 sampler, sampleSize, numFailures, numTests, chi2CriticalValue,
119 failedStat.stream().map(d -> String.format("%.3f", d))
120 .collect(Collectors.joining(", ", "[", "]"))));
121 }
122 }
123 }