1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.util.ArrayList;
20 import java.util.Arrays;
21 import java.util.List;
22 import org.apache.commons.math3.stat.inference.ChiSquareTest;
23 import org.junit.jupiter.api.Assertions;
24 import org.junit.jupiter.params.ParameterizedTest;
25 import org.junit.jupiter.params.provider.MethodSource;
26
27
28
29
30 class DiscreteSamplerParametricTest {
31 private static Iterable<DiscreteSamplerTestData> getSamplerTestData() {
32 return DiscreteSamplersList.list();
33 }
34
35 @ParameterizedTest
36 @MethodSource("getSamplerTestData")
37 void testSampling(DiscreteSamplerTestData data) {
38 final int sampleSize = 10000;
39
40 check(sampleSize,
41 data.getSampler(),
42 data.getPoints(),
43 data.getProbabilities());
44 }
45
46
47
48
49
50
51
52
53
54
55
56
57 private static void check(long sampleSize,
58 DiscreteSampler sampler,
59 int[] points,
60 double[] expected) {
61 final ChiSquareTest chiSquareTest = new ChiSquareTest();
62 final int numTests = 50;
63
64
65 int numFailures = 0;
66
67 final int numBins = points.length;
68 final long[] observed = new long[numBins];
69
70
71 final List<Double> failedStat = new ArrayList<>();
72 try {
73 for (int i = 0; i < numTests; i++) {
74 Arrays.fill(observed, 0);
75 SAMPLE: for (long j = 0; j < sampleSize; j++) {
76 final int value = sampler.sample();
77
78 for (int k = 0; k < numBins; k++) {
79 if (value == points[k]) {
80 ++observed[k];
81 continue SAMPLE;
82 }
83 }
84 }
85
86 final double p = chiSquareTest.chiSquareTest(expected, observed);
87 if (p < 0.01) {
88 failedStat.add(p);
89 ++numFailures;
90 }
91 }
92 } catch (Exception e) {
93
94 throw new RuntimeException("Unexpected", e);
95 }
96
97
98
99
100
101
102
103
104
105 if (numFailures > 3) {
106 Assertions.fail(String.format(
107 "%s: Too many failures for sample size = %d " +
108 " (%d out of %d tests failed, chi2=%s",
109 sampler, sampleSize, numFailures, numTests,
110 Arrays.toString(failedStat.toArray(new Double[0]))));
111 }
112 }
113 }