1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
import java.util.Arrays;
import java.util.Locale;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Stream;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.RandomAssert;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
32
33
34
35
36 class FastLoadedDiceRollerDiscreteSamplerTest {
37
38
39
40
41
42
43 private static SharedStateDiscreteSampler createSampler(long... frequencies) {
44 final UniformRandomProvider rng = RandomAssert.createRNG();
45 return FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies);
46 }
47
48
49
50
51
52
53
54 private static SharedStateDiscreteSampler createSampler(double... weights) {
55 final UniformRandomProvider rng = RandomAssert.createRNG();
56 return FastLoadedDiceRollerDiscreteSampler.of(rng, weights);
57 }
58
59
60
61
62
63
64 static Stream<long[]> testFactoryConstructorFrequencies() {
65 return Stream.of(
66
67 (long[]) null,
68 new long[0],
69
70 new long[] {-1, 2, 3},
71 new long[] {1, -2, 3},
72 new long[] {1, 2, -3},
73
74 new long[] {Long.MAX_VALUE, Long.MAX_VALUE},
75
76 new long[] {Long.MAX_VALUE, Long.MAX_VALUE, 2},
77
78 new long[] {Long.MAX_VALUE, Long.MAX_VALUE, Long.MAX_VALUE},
79
80 new long[1],
81 new long[4]
82 );
83 }
84
85 @ParameterizedTest
86 @MethodSource
87 void testFactoryConstructorFrequencies(long[] frequencies) {
88 Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(frequencies));
89 }
90
91
92
93
94
95
96 static Stream<double[]> testFactoryConstructorWeights() {
97 return Stream.of(
98
99 (double[]) null,
100 new double[0],
101
102 new double[] {-1, 2, 3},
103 new double[] {1, -2, 3},
104 new double[] {1, 2, -3},
105 new double[] {Double.POSITIVE_INFINITY, 2, 3},
106 new double[] {1, Double.POSITIVE_INFINITY, 3},
107 new double[] {1, 2, Double.POSITIVE_INFINITY},
108 new double[] {Double.NaN, 2, 3},
109 new double[] {1, Double.NaN, 3},
110 new double[] {1, 2, Double.NaN},
111
112 new double[1],
113 new double[4]
114 );
115 }
116
117 @ParameterizedTest
118 @MethodSource
119 void testFactoryConstructorWeights(double[] weights) {
120 Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(weights));
121 }
122
123 @Test
124 void testToString() {
125 for (final long[] observed : new long[][] {{42}, {1, 2, 3}}) {
126 final SharedStateDiscreteSampler sampler = createSampler(observed);
127 Assertions.assertTrue(sampler.toString().toLowerCase().contains("fast loaded dice roller"));
128 }
129 }
130
131 @Test
132 void testSingleCategory() {
133 final int n = 13;
134 final int[] expected = new int[n];
135 Assertions.assertArrayEquals(expected, createSampler(42).samples(n).toArray());
136 Assertions.assertArrayEquals(expected, createSampler(0.55).samples(n).toArray());
137 }
138
139 @Test
140 void testSingleFrequency() {
141 final long[] frequencies = new long[5];
142 final int category = 2;
143 frequencies[category] = 1;
144 final SharedStateDiscreteSampler sampler = createSampler(frequencies);
145 final int n = 7;
146 final int[] expected = new int[n];
147 Arrays.fill(expected, category);
148 Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
149 }
150
151 @Test
152 void testSingleWeight() {
153 final double[] weights = new double[5];
154 final int category = 3;
155 weights[category] = 1.5;
156 final SharedStateDiscreteSampler sampler = createSampler(weights);
157 final int n = 6;
158 final int[] expected = new int[n];
159 Arrays.fill(expected, category);
160 Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
161 }
162
163 @Test
164 void testIndexOfNonZero() {
165 Assertions.assertThrows(IllegalStateException.class,
166 () -> FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(new long[3]));
167 final long[] data = new long[3];
168 for (int i = 0; i < data.length; i++) {
169 data[i] = 13;
170 Assertions.assertEquals(i, FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(data));
171 data[i] = 0;
172 }
173 }
174
175 @ParameterizedTest
176 @ValueSource(longs = {0, 1, -1, Integer.MAX_VALUE, 1L << 34})
177 void testCheckArraySize(long size) {
178
179 final int max = Integer.MAX_VALUE - 8;
180
181
182 if (size > max) {
183 Assertions.assertThrows(IllegalArgumentException.class,
184 () -> FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
185 } else {
186 Assertions.assertEquals((int) size, FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
187 }
188 }
189
190
191
192
193
194
195 static Stream<long[]> testSamplesFrequencies() {
196 return Stream.of(
197
198 new long[] {0, 0, 42, 0, 0},
199
200 new long[] {1, 1, 2, 3, 1},
201 new long[] {0, 1, 1, 0, 2, 3, 1, 0},
202
203 new long[] {1, 2, 3, 1, 3},
204 new long[] {1, 0, 2, 0, 3, 1, 3},
205
206 new long[] {5126734627834L, 213267384684832L, 126781236718L, 71289979621378L}
207 );
208 }
209
210
211
212
213
214
215 @ParameterizedTest
216 @MethodSource
217 void testSamplesFrequencies(long[] expectedFrequencies) {
218 final SharedStateDiscreteSampler sampler = createSampler(expectedFrequencies);
219 final int numberOfSamples = 10000;
220 final long[] samples = new long[expectedFrequencies.length];
221 sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
222
223
224 int mapSize = 0;
225 double sum = 0;
226 for (final double f : expectedFrequencies) {
227 if (f != 0) {
228 mapSize++;
229 sum += f;
230 }
231 }
232
233
234 if (mapSize == 1) {
235 int index = 0;
236 while (index < expectedFrequencies.length) {
237 if (expectedFrequencies[index] != 0) {
238 break;
239 }
240 index++;
241 }
242 Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples");
243 return;
244 }
245
246 final double[] expected = new double[mapSize];
247 final long[] observed = new long[mapSize];
248 for (int i = 0; i < expectedFrequencies.length; i++) {
249 if (expectedFrequencies[i] != 0) {
250 --mapSize;
251 expected[mapSize] = expectedFrequencies[i] / sum;
252 observed[mapSize] = samples[i];
253 } else {
254 Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
255 }
256 }
257
258 final ChiSquareTest chiSquareTest = new ChiSquareTest();
259
260 Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
261 }
262
263
264
265
266
267
268 static Stream<double[]> testSamplesWeights() {
269 return Stream.of(
270
271 new double[] {0, 0, 0.523, 0, 0},
272
273 new double[] {0.125, 0.125, 0.25, 0.375, 0.125},
274 new double[] {0, 0.125, 0.125, 0.25, 0, 0.375, 0.125, 0},
275
276 new double[] {0.1, 0.2, 0.3, 0.1, 0.3},
277 new double[] {0.1, 0, 0.2, 0, 0.3, 0.1, 0.3},
278
279 new double[] {5 * Double.MIN_NORMAL, 2 * Double.MIN_NORMAL, 3 * Double.MIN_NORMAL, 9 * Double.MIN_NORMAL},
280 new double[] {2 * Double.MIN_NORMAL, Double.MIN_NORMAL, 0.5 * Double.MIN_NORMAL, 0.75 * Double.MIN_NORMAL},
281 new double[] {Double.MIN_VALUE, 2 * Double.MIN_VALUE, 3 * Double.MIN_VALUE, 7 * Double.MIN_VALUE},
282
283 new double[] {1.0, 2.0, Math.scalb(3.0, -32), Math.scalb(4.0, -65), 5.0},
284 new double[] {Math.scalb(1.0, 35), Math.scalb(2.0, 35), Math.scalb(3.0, -32), Math.scalb(4.0, -65), Math.scalb(5.0, 35)},
285
286 new double[] {Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE / 2, Double.MAX_VALUE / 4}
287 );
288 }
289
290
291
292
293
294
295 @ParameterizedTest
296 @MethodSource
297 void testSamplesWeights(double[] weights) {
298 final SharedStateDiscreteSampler sampler = createSampler(weights);
299 final int numberOfSamples = 10000;
300 final long[] samples = new long[weights.length];
301 sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
302
303
304 int mapSize = 0;
305 double sum = 0;
306
307 final Mean mean = new Mean();
308 for (final double w : weights) {
309 if (w != 0) {
310 mapSize++;
311 sum += w;
312 mean.increment(w);
313 }
314 }
315
316
317 if (mapSize == 1) {
318 int index = 0;
319 while (index < weights.length) {
320 if (weights[index] != 0) {
321 break;
322 }
323 index++;
324 }
325 Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples");
326 return;
327 }
328
329 final double mu = mean.getResult();
330 final int n = mapSize;
331 final double s = sum;
332 final DoubleUnaryOperator normalise = Double.isInfinite(sum) ?
333 x -> (x / mu) * n :
334 x -> x / s;
335
336 final double[] expected = new double[mapSize];
337 final long[] observed = new long[mapSize];
338 for (int i = 0; i < weights.length; i++) {
339 if (weights[i] != 0) {
340 --mapSize;
341 expected[mapSize] = normalise.applyAsDouble(weights[i]);
342 observed[mapSize] = samples[i];
343 } else {
344 Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
345 }
346 }
347
348 final ChiSquareTest chiSquareTest = new ChiSquareTest();
349
350 Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
351 }
352
353
354
355
356
357
358
359 static Stream<long[]> testSamplesWeightsMatchesFrequencies() {
360
361
362 return testSamplesFrequencies();
363 }
364
365
366
367
368
369
370
371 @ParameterizedTest
372 @MethodSource
373 void testSamplesWeightsMatchesFrequencies(long[] frequencies) {
374 final double[] weights = new double[frequencies.length];
375 for (int i = 0; i < frequencies.length; i++) {
376 final double w = frequencies[i];
377 Assumptions.assumeTrue((long) w == frequencies[i]);
378
379 weights[i] = Math.scalb(w, -35);
380 }
381 final UniformRandomProvider[] rngs = RandomAssert.createRNG(2);
382 final UniformRandomProvider rng1 = rngs[0];
383 final UniformRandomProvider rng2 = rngs[1];
384 final SharedStateDiscreteSampler sampler1 =
385 FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
386 final SharedStateDiscreteSampler sampler2 =
387 FastLoadedDiceRollerDiscreteSampler.of(rng2, weights);
388 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
389 }
390
391
392
393
394
395
396
397
398
399
400
401 @ParameterizedTest
402 @ValueSource(ints = {1023, 67, 1, -59, -1020, -1021})
403 void testScaledWeights(int scaleFactor) {
404
405 final double[] w1 = RandomAssert.createRNG().doubles(10).toArray();
406 final double scale = Math.scalb(1.0, scaleFactor);
407 final double[] w2 = Arrays.stream(w1).map(x -> x * scale).toArray();
408 final UniformRandomProvider[] rngs = RandomAssert.createRNG(2);
409 final UniformRandomProvider rng1 = rngs[0];
410 final UniformRandomProvider rng2 = rngs[1];
411 final SharedStateDiscreteSampler sampler1 =
412 FastLoadedDiceRollerDiscreteSampler.of(rng1, w1);
413 final SharedStateDiscreteSampler sampler2 =
414 FastLoadedDiceRollerDiscreteSampler.of(rng2, w2);
415 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
416 }
417
418
419
420
421
422
423
424
425 @ParameterizedTest
426 @ValueSource(ints = {13, 30, 53})
427 void testAlphaRemovesWeights(int alpha) {
428
429 final double small = Math.scalb(1.0, -(alpha + 1));
430 final double[] w1 = {1, 0.5, 0.5, 0};
431 final double[] w2 = {1, 0.5, 0.5, small};
432 final UniformRandomProvider[] rngs = RandomAssert.createRNG(3);
433 final UniformRandomProvider rng1 = rngs[0];
434 final UniformRandomProvider rng2 = rngs[1];
435 final UniformRandomProvider rng3 = rngs[2];
436
437 final int n = 10;
438 final int[] s1 = FastLoadedDiceRollerDiscreteSampler.of(rng1, w1).samples(n).toArray();
439 final int[] s2 = FastLoadedDiceRollerDiscreteSampler.of(rng2, w2, alpha).samples(n).toArray();
440 final int[] s3 = FastLoadedDiceRollerDiscreteSampler.of(rng3, w2, alpha + 1).samples(n).toArray();
441
442 Assertions.assertArrayEquals(s1, s2, "alpha parameter should ignore the small weight");
443 Assertions.assertFalse(Arrays.equals(s1, s3), "alpha+1 parameter should not ignore the small weight");
444 }
445
446 static Stream<long[]> testSharedStateSampler() {
447 return Stream.of(
448 new long[] {42},
449 new long[] {1, 1, 2, 3, 1}
450 );
451 }
452
453 @ParameterizedTest
454 @MethodSource
455 void testSharedStateSampler(long[] frequencies) {
456 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
457 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
458 final SharedStateDiscreteSampler sampler1 =
459 FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
460 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
461 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
462 }
463 }