1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.math3.stat.inference.ChiSquareTest;
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.core.source32.IntProvider;
22 import org.apache.commons.rng.core.source64.SplitMix64;
23 import org.apache.commons.rng.sampling.RandomAssert;
24 import org.junit.jupiter.api.Assertions;
25 import org.junit.jupiter.api.Test;
26
27
28
29
30
31
32
33 class MarsagliaTsangWangDiscreteSamplerTest {
34 @Test
35 void testCreateDiscreteDistributionThrowsWithNullProbabilites() {
36 assertEnumeratedSamplerConstructorThrows(null);
37 }
38
39 @Test
40 void testCreateDiscreteDistributionThrowsWithZeroLengthProbabilites() {
41 assertEnumeratedSamplerConstructorThrows(new double[0]);
42 }
43
44 @Test
45 void testCreateDiscreteDistributionThrowsWithNegativeProbabilites() {
46 assertEnumeratedSamplerConstructorThrows(new double[] {-1, 0.1, 0.2});
47 }
48
49 @Test
50 void testCreateDiscreteDistributionThrowsWithNaNProbabilites() {
51 assertEnumeratedSamplerConstructorThrows(new double[] {0.1, Double.NaN, 0.2});
52 }
53
54 @Test
55 void testCreateDiscreteDistributionThrowsWithInfiniteProbabilites() {
56 assertEnumeratedSamplerConstructorThrows(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2});
57 }
58
59 @Test
60 void testCreateDiscreteDistributionThrowsWithInfiniteSumProbabilites() {
61 assertEnumeratedSamplerConstructorThrows(new double[] {Double.MAX_VALUE, Double.MAX_VALUE});
62 }
63
64 @Test
65 void testCreateDiscreteDistributionThrowsWithZeroSumProbabilites() {
66 assertEnumeratedSamplerConstructorThrows(new double[4]);
67 }
68
69
70
71
72
73
74 private static void assertEnumeratedSamplerConstructorThrows(double[] probabilities) {
75 final UniformRandomProvider rng = new SplitMix64(0L);
76 Assertions.assertThrows(IllegalArgumentException.class,
77 () -> MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities));
78 }
79
80
81
82
83 @Test
84 void testToString() {
85 final UniformRandomProvider rng = new SplitMix64(0L);
86 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, new double[] {0.5, 0.5});
87 String text = sampler.toString();
88 for (String item : new String[] {"Marsaglia", "Tsang", "Wang"}) {
89 Assertions.assertTrue(text.contains(item), () -> "toString missing: " + item);
90 }
91 }
92
93
94
95
96
97
98
99 @Test
100 void testOffsetSamples() {
101
102
103 final int[] prob = new int[6];
104 prob[0] = 1;
105 prob[1] = 1 + 1 << 6;
106 prob[2] = 1 + 1 << 12;
107 prob[3] = 1 + 1 << 18;
108 prob[4] = 1 + 1 << 24;
109
110 prob[5] = (1 << 30) - (prob[0] + prob[1] + prob[2] + prob[3] + prob[4]);
111
112
113
114 int n1 = 0;
115 int n2 = 0;
116 int n3 = 0;
117 int n4 = 0;
118 for (final int m : prob) {
119 n1 += getBase64Digit(m, 1);
120 n2 += getBase64Digit(m, 2);
121 n3 += getBase64Digit(m, 3);
122 n4 += getBase64Digit(m, 4);
123 }
124
125 final int t1 = n1 << 24;
126 final int t2 = t1 + (n2 << 18);
127 final int t3 = t2 + (n3 << 12);
128 final int t4 = t3 + (n4 << 6);
129
130
131 final int[] values = new int[] {0, t1, t2, t3, t4, 0xffffffff};
132 for (int i = 0; i < values.length; i++) {
133 values[i] <<= 2;
134 }
135
136 final UniformRandomProvider rng1 = new FixedSequenceIntProvider(values);
137 final UniformRandomProvider rng2 = new FixedSequenceIntProvider(values);
138 final UniformRandomProvider rng3 = new FixedSequenceIntProvider(values);
139
140
141 final int offset1 = 1;
142 final int offset2 = 1 << 8;
143 final int offset3 = 1 << 16;
144
145 final double[] p1 = createProbabilities(offset1, prob);
146 final double[] p2 = createProbabilities(offset2, prob);
147 final double[] p3 = createProbabilities(offset3, prob);
148
149 final SharedStateDiscreteSampler sampler1 = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng1, p1);
150 final SharedStateDiscreteSampler sampler2 = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng2, p2);
151 final SharedStateDiscreteSampler sampler3 = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng3, p3);
152
153 for (int i = 0; i < values.length; i++) {
154
155 final int s1 = sampler1.sample() - offset1;
156 final int s2 = sampler2.sample() - offset2;
157 final int s3 = sampler3.sample() - offset3;
158 Assertions.assertEquals(s1, s2, "Offset sample 1 and 2 do not match");
159 Assertions.assertEquals(s1, s3, "Offset Sample 1 and 3 do not match");
160 }
161 }
162
163
164
165
166
167
168
169
170 private static double[] createProbabilities(int offset, int[] prob) {
171 double[] probabilities = new double[offset + prob.length];
172 for (int i = 0; i < prob.length; i++) {
173 probabilities[i + offset] = prob[i];
174 }
175 return probabilities;
176 }
177
178
179
180
181 @Test
182 void testRealProbabilityDistributionSamples() {
183
184 final double[] probabilities = new double[11];
185 final UniformRandomProvider rng = RandomAssert.createRNG();
186 for (int i = 0; i < probabilities.length; i++) {
187 probabilities[i] = rng.nextDouble();
188 }
189
190
191 final UniformRandomProvider dummyRng = new FixedSequenceIntProvider(new int[] {0xffffffff});
192 final SharedStateDiscreteSampler dummySampler = MarsagliaTsangWangDiscreteSampler.Enumerated.of(dummyRng, probabilities);
193
194 dummySampler.sample();
195
196
197 final SharedStateDiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
198
199 final int numberOfSamples = 10000;
200 final long[] samples = new long[probabilities.length];
201 for (int i = 0; i < numberOfSamples; i++) {
202 samples[sampler.sample()]++;
203 }
204
205 final ChiSquareTest chiSquareTest = new ChiSquareTest();
206
207 Assertions.assertFalse(chiSquareTest.chiSquareTest(probabilities, samples, 0.001));
208 }
209
210
211
212
213
214 @Test
215 void testStorageRequirements8() {
216
217
218
219
220
221
222 checkStorageRequirements(8, 0.06);
223 }
224
225
226
227
228
229 @Test
230 void testStorageRequirements16() {
231
232
233
234
235
236
237 checkStorageRequirements(16, 17.0);
238 }
239
240
241
242
243
244
245
246
247 private static void checkStorageRequirements(int k, double expectedLimitMB) {
248
249
250
251 final int maxSamples = 1 << k;
252
253
254
255
256 final int m = (1 << (30 - k)) - 1;
257
258
259 final long sum = (long) maxSamples * m;
260 final int total = 1 << 30;
261 Assertions.assertTrue(sum < total, "Worst case uniform distribution is above 2^30");
262
263
264 final int d1 = getBase64Digit(m, 1);
265 final int d2 = getBase64Digit(m, 2);
266 final int d3 = getBase64Digit(m, 3);
267 final int d4 = getBase64Digit(m, 4);
268 final int d5 = getBase64Digit(m, 5);
269
270 int bytes;
271 if (k <= 8) {
272 bytes = 1;
273 } else if (k <= 16) {
274 bytes = 2;
275 } else {
276 bytes = 4;
277 }
278 final double storageMB = bytes * 1e-6 * (d1 + d2 + d3 + d4 + d5) * maxSamples;
279 Assertions.assertTrue(storageMB < expectedLimitMB,
280 () -> "Worst case uniform distribution storage " + storageMB +
281 "MB is above expected limit: " + expectedLimitMB);
282 }
283
284
285
286
287
288
289
290
291 private static int getBase64Digit(int m, int k) {
292 return (m >>> (30 - 6 * k)) & 63;
293 }
294
295
296
297
298 @Test
299 void testCreatePoissonDistributionThrowsWithMeanLargerThanUpperBound() {
300 final UniformRandomProvider rng = new FixedRNG();
301 final double mean = 1025;
302 Assertions.assertThrows(IllegalArgumentException.class,
303 () -> MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean));
304 }
305
306
307
308
309 @Test
310 void testCreatePoissonDistributionThrowsWithZeroMean() {
311 final UniformRandomProvider rng = new FixedRNG();
312 final double mean = 0;
313 Assertions.assertThrows(IllegalArgumentException.class,
314 () -> MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean));
315 }
316
317
318
319
320 @Test
321 void testCreatePoissonDistributionWithMaximumMean() {
322 final UniformRandomProvider rng = new FixedRNG();
323 final double mean = 1024;
324 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
325
326
327 sampler.sample();
328 }
329
330
331
332
333
334 @Test
335 void testCreatePoissonDistributionWithSmallMean() {
336 final UniformRandomProvider rng = new FixedRNG();
337 final double mean = 0.25;
338 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
339
340
341 sampler.sample();
342 }
343
344
345
346
347
348
349 @Test
350 void testCreatePoissonDistributionWithMediumMean() {
351 final UniformRandomProvider rng = new FixedRNG();
352 final double mean = 21.4;
353 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
354
355
356 sampler.sample();
357 }
358
359
360
361
362 @Test
363 void testCreateBinomialDistributionThrowsWithTrialsBelow0() {
364 final UniformRandomProvider rng = new FixedRNG();
365 final int trials = -1;
366 final double p = 0.5;
367 Assertions.assertThrows(IllegalArgumentException.class,
368 () -> MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p));
369 }
370
371
372
373
374 @Test
375 void testCreateBinomialDistributionThrowsWithTrialsAboveMax() {
376 final UniformRandomProvider rng = new FixedRNG();
377 final int trials = 1 << 16;
378 final double p = 0.5;
379 Assertions.assertThrows(IllegalArgumentException.class,
380 () -> MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p));
381 }
382
383
384
385
386 @Test
387 void testCreateBinomialDistributionThrowsWithProbabilityBelow0() {
388 final UniformRandomProvider rng = new FixedRNG();
389 final int trials = 1;
390 final double p = -0.5;
391 Assertions.assertThrows(IllegalArgumentException.class,
392 () -> MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p));
393 }
394
395
396
397
398 @Test
399 void testCreateBinomialDistributionThrowsWithProbabilityAbove1() {
400 final UniformRandomProvider rng = new FixedRNG();
401 final int trials = 1;
402 final double p = 1.5;
403 Assertions.assertThrows(IllegalArgumentException.class,
404 () -> MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p));
405 }
406
407
408
409
410
411 @Test
412 void testCreateBinomialDistributionWithSmallestP0ValueAndHighestProbabilityOfSuccess() {
413 final UniformRandomProvider rng = new FixedRNG();
414
415
416
417
418
419
420
421
422 final int trials = (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
423 final double p = 0.5;
424
425 Assertions.assertEquals(Double.MIN_VALUE, getBinomialP0(trials, p), 0, "Invalid test set-up for p(0)");
426 Assertions.assertEquals(0, getBinomialP0(trials + 1, p), 0, "Invalid test set-up for p(0)");
427
428
429 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
430 sampler.sample();
431 }
432
433
434
435
436
437 @Test
438 void testCreateBinomialDistributionThrowsWhenP0IsZero() {
439 final UniformRandomProvider rng = new FixedRNG();
440
441 final int trials = 1 + (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
442 final double p = 0.5;
443
444 Assertions.assertEquals(0, getBinomialP0(trials, p), 0, "Invalid test set-up for p(0)");
445 Assertions.assertThrows(IllegalArgumentException.class,
446 () -> MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p));
447 }
448
449
450
451
452
453 @Test
454 void testCreateBinomialDistributionWithLargestTrialsAndSmallestProbabilityOfSuccess() {
455 final UniformRandomProvider rng = new FixedRNG();
456
457
458
459
460
461
462
463
464 final int trials = (1 << 16) - 1;
465 double p = 1 - Math.exp(Math.log(Double.MIN_VALUE) / trials);
466
467
468 Assertions.assertEquals(Double.MIN_VALUE, getBinomialP0(trials, p), 0, "Invalid test set-up for p(0)");
469
470
471 double upper = p * 2;
472 Assertions.assertEquals(0, getBinomialP0(trials, upper), 0, "Invalid test set-up for p(0)");
473
474 double lower = p;
475 while (Double.doubleToRawLongBits(lower) + 1 < Double.doubleToRawLongBits(upper)) {
476 final double mid = (upper + lower) / 2;
477 if (getBinomialP0(trials, mid) == 0) {
478 upper = mid;
479 } else {
480 lower = mid;
481 }
482 }
483 p = lower;
484
485
486 Assertions.assertEquals(Double.MIN_VALUE, getBinomialP0(trials, p), 0, "Invalid test set-up for p(0)");
487 Assertions.assertEquals(0, getBinomialP0(trials, Math.nextUp(p)), 0, "Invalid test set-up for p(0)");
488
489 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
490
491 sampler.sample();
492 }
493
494
495
496
497
498
499
500
501 private static double getBinomialP0(int trials, double probabilityOfSuccess) {
502 return Math.exp(trials * Math.log(1 - probabilityOfSuccess));
503 }
504
505
506
507
508 @Test
509 void testCreateBinomialDistributionWithProbability0() {
510 final UniformRandomProvider rng = new FixedRNG();
511 final int trials = 1000000;
512 final double p = 0;
513 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
514 for (int i = 0; i < 5; i++) {
515 Assertions.assertEquals(0, sampler.sample());
516 }
517
518 Assertions.assertTrue(sampler.toString().contains("Binomial"));
519 }
520
521
522
523
524
525 @Test
526 void testCreateBinomialDistributionWithProbability1() {
527 final UniformRandomProvider rng = new FixedRNG();
528 final int trials = 1000000;
529 final double p = 1;
530 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
531 for (int i = 0; i < 5; i++) {
532 Assertions.assertEquals(trials, sampler.sample());
533 }
534
535 Assertions.assertTrue(sampler.toString().contains("Binomial"));
536 }
537
538
539
540
541
542
543 @Test
544 void testCreateBinomialDistributionWithLargeNumberOfTrials() {
545 final UniformRandomProvider rng = new FixedRNG();
546 final int trials = 65000;
547 final double p = 0.01;
548 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
549
550
551 sampler.sample();
552 }
553
554
555
556
557
558 @Test
559 void testCreateBinomialDistributionWithProbability50Percent() {
560 final UniformRandomProvider rng = new FixedRNG();
561 final int trials = 10;
562 final double p = 0.5;
563 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
564
565
566 sampler.sample();
567 }
568
569
570
571
572
573 @Test
574 void testBinomialSamplerToString() {
575 final UniformRandomProvider rng = new FixedRNG();
576 final int trials = 10;
577 final double p1 = 0.4;
578 final double p2 = 1 - p1;
579 final DiscreteSampler sampler1 = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p1);
580 final DiscreteSampler sampler2 = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p2);
581 Assertions.assertEquals(sampler1.toString(), sampler2.toString());
582 }
583
584
585
586
587 @Test
588 void testSharedStateSamplerWith8bitStorage() {
589 testSharedStateSampler(0, new int[] {1, 2, 3, 4, 5});
590 }
591
592
593
594
595 @Test
596 void testSharedStateSamplerWith16bitStorage() {
597 testSharedStateSampler(1 << 8, new int[] {1, 2, 3, 4, 5});
598 }
599
600
601
602
603 @Test
604 void testSharedStateSamplerWith32bitStorage() {
605 testSharedStateSampler(1 << 16, new int[] {1, 2, 3, 4, 5});
606 }
607
608
609
610
611
612
613
614
615 private static void testSharedStateSampler(int offset, int[] prob) {
616 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
617 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
618 double[] probabilities = createProbabilities(offset, prob);
619 final SharedStateDiscreteSampler sampler1 =
620 MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng1, probabilities);
621 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
622 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
623 }
624
625
626
627
628 @Test
629 void testSharedStateSamplerWithFixedBinomialDistribution() {
630 testSharedStateSampler(10, 1.0);
631 }
632
633
634
635
636
637 @Test
638 void testSharedStateSamplerWithInvertedBinomialDistribution() {
639 testSharedStateSampler(10, 0.999);
640 }
641
642
643
644
645
646
647
648
649 private static void testSharedStateSampler(int trials, double probabilityOfSuccess) {
650 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
651 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
652 final SharedStateDiscreteSampler sampler1 =
653 MarsagliaTsangWangDiscreteSampler.Binomial.of(rng1, trials, probabilityOfSuccess);
654 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
655 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
656 }
657
658
659
660
661 private static class FixedSequenceIntProvider extends IntProvider {
662
663 private int count;
664
665 private final int[] values;
666
667
668
669
670
671
672 FixedSequenceIntProvider(int[] values) {
673 this.values = values;
674 }
675
676 @Override
677 public int next() {
678
679 return values[count++ % values.length];
680 }
681 }
682
683
684
685
686 private static class FixedRNG extends IntProvider {
687 @Override
688 public int next() {
689 return 0xffffffff;
690 }
691 }
692 }