1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.math3.distribution.BinomialDistribution;
20 import org.apache.commons.math3.distribution.PoissonDistribution;
21 import org.apache.commons.math3.stat.inference.ChiSquareTest;
22 import org.apache.commons.rng.UniformRandomProvider;
23 import org.apache.commons.rng.sampling.RandomAssert;
24 import org.junit.jupiter.api.Assertions;
25 import org.junit.jupiter.api.Test;
26 import org.junit.jupiter.params.ParameterizedTest;
27 import org.junit.jupiter.params.provider.ValueSource;
28
29
30
31
32 class GuideTableDiscreteSamplerTest {
33 @Test
34 void testConstructorThrowsWithNullProbabilites() {
35 assertConstructorThrows(null, 1.0);
36 }
37
38 @Test
39 void testConstructorThrowsWithZeroLengthProbabilites() {
40 assertConstructorThrows(new double[0], 1.0);
41 }
42
43 @Test
44 void testConstructorThrowsWithNegativeProbabilites() {
45 assertConstructorThrows(new double[] {-1, 0.1, 0.2}, 1.0);
46 }
47
48 @Test
49 void testConstructorThrowsWithNaNProbabilites() {
50 assertConstructorThrows(new double[] {0.1, Double.NaN, 0.2}, 1.0);
51 }
52
53 @Test
54 void testConstructorThrowsWithInfiniteProbabilites() {
55 assertConstructorThrows(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2}, 1.0);
56 }
57
58 @Test
59 void testConstructorThrowsWithInfiniteSumProbabilites() {
60 assertConstructorThrows(new double[] {Double.MAX_VALUE, Double.MAX_VALUE}, 1.0);
61 }
62
63 @Test
64 void testConstructorThrowsWithZeroSumProbabilites() {
65 assertConstructorThrows(new double[4], 1.0);
66 }
67
68 @Test
69 void testConstructorThrowsWithZeroAlpha() {
70 assertConstructorThrows(new double[] {0.5, 0.5}, 0.0);
71 }
72
73 @Test
74 void testConstructorThrowsWithNegativeAlpha() {
75 assertConstructorThrows(new double[] {0.5, 0.5}, -1.0);
76 }
77
78
79
80
81
82
83
84 private static void assertConstructorThrows(double[] probabilities, double alpha) {
85 final UniformRandomProvider rng = RandomAssert.seededRNG();
86 Assertions.assertThrows(IllegalArgumentException.class,
87 () -> GuideTableDiscreteSampler.of(rng, probabilities, alpha));
88 }
89
90 @Test
91 void testToString() {
92 final UniformRandomProvider rng = RandomAssert.createRNG();
93 final SharedStateDiscreteSampler sampler = GuideTableDiscreteSampler.of(rng, new double[] {0.5, 0.5}, 1.0);
94 Assertions.assertTrue(sampler.toString().toLowerCase().contains("guide table"));
95 }
96
97
98
99
100 @Test
101 void testBinomialSamples() {
102 final int trials = 67;
103 final double probabilityOfSuccess = 0.345;
104 final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
105 final double[] expected = new double[trials + 1];
106 for (int i = 0; i < expected.length; i++) {
107 expected[i] = dist.probability(i);
108 }
109 checkSamples(expected, 1.0);
110 }
111
112
113
114
115 @Test
116 void testPoissonSamples() {
117 final double mean = 3.14;
118 final PoissonDistribution dist = new PoissonDistribution(null, mean,
119 PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
120 final int maxN = dist.inverseCumulativeProbability(1 - 1e-6);
121 final double[] expected = new double[maxN];
122 for (int i = 0; i < expected.length; i++) {
123 expected[i] = dist.probability(i);
124 }
125 checkSamples(expected, 1.0);
126 }
127
128
129
130
131
132 @ParameterizedTest
133 @ValueSource(doubles = {1.0, 0.1, 10.0})
134 void testNonUniformSamplesWithProbabilities(double alpha) {
135 final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
136 checkSamples(expected, alpha);
137 }
138
139
140
141
142
143 @Test
144 void testNonUniformSamplesWithObservations() {
145 final double[] expected = {1, 2, 3, 1, 3};
146 checkSamples(expected, 1.0);
147 }
148
149
150
151
152
153 @Test
154 void testNonUniformSamplesWithZeroProbabilities() {
155 final double[] expected = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0};
156 checkSamples(expected, 1.0);
157 }
158
159
160
161
162
163 @Test
164 void testNonUniformSamplesWithZeroObservations() {
165 final double[] expected = {1, 2, 3, 0, 1, 3, 0};
166 checkSamples(expected, 1.0);
167 }
168
169
170
171
172
173 @Test
174 void testUniformSamplesWithNoObservationLessThanTheMean() {
175 final double[] expected = {2, 2, 2, 2, 2, 2};
176 checkSamples(expected, 1.0);
177 }
178
179
180
181
182
183
184
185
186
187
188 private static void checkSamples(double[] probabilies, double alpha) {
189 final UniformRandomProvider rng = RandomAssert.createRNG();
190 final SharedStateDiscreteSampler sampler = GuideTableDiscreteSampler.of(rng, probabilies, alpha);
191
192 final int numberOfSamples = 10000;
193 final long[] samples = new long[probabilies.length];
194 for (int i = 0; i < numberOfSamples; i++) {
195 samples[sampler.sample()]++;
196 }
197
198
199
200 int mapSize = 0;
201 for (int i = 0; i < probabilies.length; i++) {
202 if (probabilies[i] != 0) {
203 mapSize++;
204 }
205 }
206
207 final double[] expected = new double[mapSize];
208 final long[] observed = new long[mapSize];
209 for (int i = 0; i < probabilies.length; i++) {
210 if (probabilies[i] == 0) {
211 Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
212 } else {
213
214 --mapSize;
215 expected[mapSize] = probabilies[i];
216 observed[mapSize] = samples[i];
217 }
218 }
219
220 final ChiSquareTest chiSquareTest = new ChiSquareTest();
221
222 Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
223 }
224
225
226
227
228 @Test
229 void testSharedStateSampler() {
230 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
231 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
232 final double[] probabilities = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0};
233 final SharedStateDiscreteSampler sampler1 =
234 GuideTableDiscreteSampler.of(rng1, probabilities);
235 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
236 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
237 }
238 }