1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.statistics.distribution;
18
19 import java.lang.reflect.Modifier;
20 import java.util.ArrayList;
21 import java.util.Arrays;
22 import java.util.Collections;
23 import java.util.Properties;
24 import java.util.function.Function;
25 import java.util.stream.IntStream;
26 import java.util.stream.Stream;
27 import org.apache.commons.math3.util.MathArrays;
28 import org.apache.commons.rng.simple.RandomSource;
29 import org.apache.commons.statistics.distribution.DistributionTestData.DiscreteDistributionTestData;
30 import org.junit.jupiter.api.Assertions;
31 import org.junit.jupiter.api.Assumptions;
32 import org.junit.jupiter.api.TestInstance;
33 import org.junit.jupiter.api.TestInstance.Lifecycle;
34 import org.junit.jupiter.params.ParameterizedTest;
35 import org.junit.jupiter.params.provider.Arguments;
36 import org.junit.jupiter.params.provider.MethodSource;
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201 @TestInstance(Lifecycle.PER_CLASS)
202 abstract class BaseDiscreteDistributionTest
203 extends BaseDistributionTest<DiscreteDistribution, DiscreteDistributionTestData> {
204
205
206 private static final int SUM_RANGE_TOO_LARGE = 50;
207
208 @Override
209 DiscreteDistributionTestData makeDistributionData(Properties properties) {
210 return new DiscreteDistributionTestData(properties);
211 }
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226 Stream<Arguments> streamCdfTestPoints(TestName name) {
227 return stream(name,
228 DiscreteDistributionTestData::getCdfPoints);
229 }
230
231
232
233
234
235
236
237 Stream<Arguments> testProbability() {
238 return stream(TestName.PMF,
239 DiscreteDistributionTestData::getPmfPoints,
240 DiscreteDistributionTestData::getPmfValues);
241 }
242
243
244
245
246
247
248
249 Stream<Arguments> testLogProbability() {
250 return stream(TestName.LOGPMF,
251 DiscreteDistributionTestData::getPmfPoints,
252 DiscreteDistributionTestData::getLogPmfValues);
253 }
254
255
256
257
258
259
260
261 Stream<Arguments> testCumulativeProbability() {
262 return stream(TestName.CDF,
263 DiscreteDistributionTestData::getCdfPoints,
264 DiscreteDistributionTestData::getCdfValues);
265 }
266
267
268
269
270
271
272
273
274
275 Stream<Arguments> testSurvivalProbability() {
276 return stream(TestName.SF,
277 DiscreteDistributionTestData::getSfPoints,
278 DiscreteDistributionTestData::getSfValues);
279 }
280
281
282
283
284
285
286
287 Stream<Arguments> testCumulativeProbabilityHighPrecision() {
288 return stream(TestName.CDF_HP,
289 DiscreteDistributionTestData::getCdfHpPoints,
290 DiscreteDistributionTestData::getCdfHpValues);
291 }
292
293
294
295
296
297
298
299 Stream<Arguments> testSurvivalProbabilityHighPrecision() {
300 return stream(TestName.SF_HP,
301 DiscreteDistributionTestData::getSfHpPoints,
302 DiscreteDistributionTestData::getSfHpValues);
303 }
304
305
306
307
308
309
310
311 Stream<Arguments> testInverseCumulativeProbability() {
312 return stream(TestName.ICDF,
313 DiscreteDistributionTestData::getIcdfPoints,
314 DiscreteDistributionTestData::getIcdfValues);
315 }
316
317
318
319
320
321
322
323 Stream<Arguments> testInverseSurvivalProbability() {
324 return stream(TestName.ISF,
325 DiscreteDistributionTestData::getIsfPoints,
326 DiscreteDistributionTestData::getIsfValues);
327 }
328
329
330
331
332
333
334 Stream<Arguments> testCumulativeProbabilityInverseMapping() {
335 return stream(TestName.CDF_MAPPING,
336 DiscreteDistributionTestData::getCdfPoints);
337 }
338
339
340
341
342
343
344 Stream<Arguments> testSurvivalProbabilityInverseMapping() {
345 return stream(TestName.SF_MAPPING,
346 DiscreteDistributionTestData::getSfPoints);
347 }
348
349
350
351
352
353
354
355 Stream<Arguments> testCumulativeProbabilityHighPrecisionInverseMapping() {
356 return stream(TestName.CDF_HP_MAPPING,
357 DiscreteDistributionTestData::getCdfHpPoints);
358 }
359
360
361
362
363
364
365
366 Stream<Arguments> testSurvivalProbabilityHighPrecisionInverseMapping() {
367 return stream(TestName.SF_HP_MAPPING,
368 DiscreteDistributionTestData::getSfHpPoints);
369 }
370
371
372
373
374
375
376
377 Stream<Arguments> testSurvivalAndCumulativeProbabilityComplement() {
378
379
380 return streamCdfTestPoints(TestName.COMPLEMENT);
381 }
382
383
384
385
386
387
388
389
390 Stream<Arguments> testConsistency() {
391
392
393 return streamCdfTestPoints(TestName.CONSISTENCY);
394 }
395
396
397
398
399
400
401 Stream<Arguments> testOutsideSupport() {
402 return stream(TestName.OUTSIDE_SUPPORT);
403 }
404
405
406
407
408
409
410
411 Stream<Arguments> testSamplingPMF() {
412 return stream(TestName.SAMPLING_PMF,
413 DiscreteDistributionTestData::getPmfPoints,
414 DiscreteDistributionTestData::getPmfValues);
415 }
416
417
418
419
420
421
422
423
424
425 Stream<Arguments> testSampling() {
426 return stream(TestName.SAMPLING);
427 }
428
429
430
431
432
433
434
435
436
437
438 Stream<Arguments> testProbabilitySums() {
439
440
441 final double scale = 10;
442 final TestName cdf = TestName.CDF;
443 final Function<DiscreteDistributionTestData, DoubleTolerance> tolerance =
444 d -> createAbsOrRelTolerance(d.getAbsoluteTolerance(cdf) * scale,
445 d.getRelativeTolerance(cdf) * scale);
446 final TestName name = TestName.PMF_SUM;
447 return stream(d -> d.isDisabled(name),
448 DiscreteDistributionTestData::getCdfPoints,
449 DiscreteDistributionTestData::getCdfValues,
450 tolerance, name.toString());
451 }
452
453
454
455
456
457
458
459 Stream<Arguments> testSupport() {
460 return streamArguments(TestName.SUPPORT,
461 d -> Arguments.of(namedDistribution(d.getParameters()), d.getLower(), d.getUpper()));
462 }
463
464
465
466
467
468
469
470 Stream<Arguments> testMoments() {
471 final TestName name = TestName.MOMENTS;
472 return streamArguments(name,
473 d -> Arguments.of(namedDistribution(d.getParameters()), d.getMean(), d.getVariance(),
474 createTestTolerance(d, name)));
475 }
476
477
478
479
480
481
482 Stream<Arguments> testMedian() {
483 return streamArguments(TestName.MEDIAN,
484 d -> Arguments.of(namedDistribution(d.getParameters())));
485 }
486
487
488
489
490
491
492
493
494
495
496
497
498 @ParameterizedTest
499 @MethodSource
500 final void testProbability(DiscreteDistribution dist,
501 int[] points,
502 double[] values,
503 DoubleTolerance tolerance) {
504 for (int i = 0; i < points.length; i++) {
505 final int x = points[i];
506 TestUtils.assertEquals(values[i],
507 dist.probability(x), tolerance,
508 () -> "Incorrect probability mass value returned for " + x);
509 }
510 }
511
512
513
514
515 @ParameterizedTest
516 @MethodSource
517 final void testLogProbability(DiscreteDistribution dist,
518 int[] points,
519 double[] values,
520 DoubleTolerance tolerance) {
521 for (int i = 0; i < points.length; i++) {
522 final int x = points[i];
523 TestUtils.assertEquals(values[i],
524 dist.logProbability(x), tolerance,
525 () -> "Incorrect log probability mass value returned for " + x);
526 }
527 }
528
529
530
531
532 @ParameterizedTest
533 @MethodSource
534 final void testCumulativeProbability(DiscreteDistribution dist,
535 int[] points,
536 double[] values,
537 DoubleTolerance tolerance) {
538
539 for (int i = 0; i < points.length; i++) {
540 final int x = points[i];
541 TestUtils.assertEquals(values[i],
542 dist.cumulativeProbability(x),
543 tolerance,
544 () -> "Incorrect cumulative probability value returned for " + x);
545 }
546 }
547
548
549
550
551 @ParameterizedTest
552 @MethodSource
553 final void testSurvivalProbability(DiscreteDistribution dist,
554 int[] points,
555 double[] values,
556 DoubleTolerance tolerance) {
557 for (int i = 0; i < points.length; i++) {
558 final double x = points[i];
559 TestUtils.assertEquals(
560 values[i],
561 dist.survivalProbability(points[i]),
562 tolerance,
563 () -> "Incorrect survival probability value returned for " + x);
564 }
565 }
566
567
568
569
570
571 @ParameterizedTest
572 @MethodSource
573 final void testCumulativeProbabilityHighPrecision(DiscreteDistribution dist,
574 int[] points,
575 double[] values,
576 DoubleTolerance tolerance) {
577 assertHighPrecision(tolerance, values);
578 testCumulativeProbability(dist, points, values, tolerance);
579 }
580
581
582
583
584
585 @ParameterizedTest
586 @MethodSource
587 final void testSurvivalProbabilityHighPrecision(DiscreteDistribution dist,
588 int[] points,
589 double[] values,
590 DoubleTolerance tolerance) {
591 assertHighPrecision(tolerance, values);
592 testSurvivalProbability(dist, points, values, tolerance);
593 }
594
595
596
597
598
599
600 @ParameterizedTest
601 @MethodSource
602 final void testInverseCumulativeProbability(DiscreteDistribution dist,
603 double[] points,
604 int[] values) {
605 final int lower = dist.getSupportLowerBound();
606 final int upper = dist.getSupportUpperBound();
607 for (int i = 0; i < points.length; i++) {
608 final int x = values[i];
609 if (x < lower || x > upper) {
610 continue;
611 }
612 final double p = points[i];
613 Assertions.assertEquals(
614 x,
615 dist.inverseCumulativeProbability(p),
616 () -> "Incorrect inverse cumulative probability value returned for " + p);
617 }
618 }
619
620
621
622
623
624
625 @ParameterizedTest
626 @MethodSource
627 final void testInverseSurvivalProbability(DiscreteDistribution dist,
628 double[] points,
629 int[] values) {
630 final int lower = dist.getSupportLowerBound();
631 final int upper = dist.getSupportUpperBound();
632 for (int i = 0; i < points.length; i++) {
633 final int x = values[i];
634 if (x < lower || x > upper) {
635 continue;
636 }
637 final double p = points[i];
638 Assertions.assertEquals(
639 x,
640 dist.inverseSurvivalProbability(p),
641 () -> "Incorrect inverse survival probability value returned for " + p);
642 }
643 }
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672 @ParameterizedTest
673 @MethodSource
674 final void testCumulativeProbabilityInverseMapping(DiscreteDistribution dist,
675 int[] points) {
676 final int lower = dist.getSupportLowerBound();
677 final int upper = dist.getSupportUpperBound();
678 for (int i = 0; i < points.length; i++) {
679 final int x = points[i];
680 if (x < lower || x > upper) {
681 continue;
682 }
683 final double p = dist.cumulativeProbability(x);
684 if ((int) p == p) {
685
686 continue;
687 }
688 final double x1 = dist.inverseCumulativeProbability(p);
689 Assertions.assertEquals(
690 x,
691 x1,
692 () -> "Incorrect CDF inverse value returned for " + p);
693
694 final double pp = Math.nextUp(p);
695 if (x != upper && pp != 1 && p != dist.cumulativeProbability(x + 1)) {
696 final double x2 = dist.inverseCumulativeProbability(pp);
697 Assertions.assertEquals(
698 x + 1,
699 x2,
700 () -> "Incorrect CDF inverse value returned for " + pp);
701 }
702
703
704 if (x != lower) {
705 final double pm1 = dist.cumulativeProbability(x - 1);
706 final double px = (pm1 + p) / 2;
707 if (px > pm1) {
708 final double xx = dist.inverseCumulativeProbability(px);
709 Assertions.assertEquals(
710 x,
711 xx,
712 () -> "Incorrect CDF inverse value returned for " + px);
713 }
714 }
715 }
716 }
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745 @ParameterizedTest
746 @MethodSource
747 final void testSurvivalProbabilityInverseMapping(DiscreteDistribution dist,
748 int[] points) {
749 final int lower = dist.getSupportLowerBound();
750 final int upper = dist.getSupportUpperBound();
751 for (int i = 0; i < points.length; i++) {
752 final int x = points[i];
753 if (x < lower || x > upper) {
754 continue;
755 }
756 final double p = dist.survivalProbability(x);
757 if ((int) p == p) {
758
759 continue;
760 }
761 final double x1 = dist.inverseSurvivalProbability(p);
762 Assertions.assertEquals(
763 x,
764 x1,
765 () -> "Incorrect SF inverse value returned for " + p);
766
767
768 final double pp = Math.nextDown(p);
769 if (x != upper && pp != 0 && p != dist.survivalProbability(x + 1)) {
770 final double x2 = dist.inverseSurvivalProbability(pp);
771 Assertions.assertEquals(
772 x + 1,
773 x2,
774 () -> "Incorrect SF inverse value returned for " + pp);
775 }
776
777
778 if (x != lower) {
779 final double pm1 = dist.survivalProbability(x - 1);
780 final double px = (pm1 + p) / 2;
781 if (px < pm1) {
782 final double xx = dist.inverseSurvivalProbability(px);
783 Assertions.assertEquals(
784 x,
785 xx,
786 () -> "Incorrect CDF inverse value returned for " + px);
787 }
788 }
789 }
790 }
791
792
793
794
795
796
797 @ParameterizedTest
798 @MethodSource
799 final void testCumulativeProbabilityHighPrecisionInverseMapping(
800 DiscreteDistribution dist,
801 int[] points) {
802 testCumulativeProbabilityInverseMapping(dist, points);
803 }
804
805
806
807
808
809
810 @ParameterizedTest
811 @MethodSource
812 final void testSurvivalProbabilityHighPrecisionInverseMapping(
813 DiscreteDistribution dist,
814 int[] points) {
815 testSurvivalProbabilityInverseMapping(dist, points);
816 }
817
818
819
820
821
822 @ParameterizedTest
823 @MethodSource
824 final void testSurvivalAndCumulativeProbabilityComplement(DiscreteDistribution dist,
825 int[] points,
826 DoubleTolerance tolerance) {
827 for (final int x : points) {
828 TestUtils.assertEquals(
829 1.0,
830 dist.survivalProbability(x) + dist.cumulativeProbability(x),
831 tolerance,
832 () -> "survival + cumulative probability were not close to 1.0 for " + x);
833 }
834 }
835
836
837
838
839
840
841 @ParameterizedTest
842 @MethodSource
843 final void testConsistency(DiscreteDistribution dist,
844 int[] points,
845 DoubleTolerance tolerance) {
846 final int upper = dist.getSupportUpperBound();
847 for (int i = 0; i < points.length; i++) {
848 final int x0 = points[i];
849
850
851 Assertions.assertEquals(
852 0.0,
853 dist.probability(x0, x0),
854 () -> "Non-zero probability(x, x) for " + x0);
855
856
857 if (x0 < upper) {
858 Assertions.assertEquals(
859 dist.probability(x0 + 1),
860 dist.probability(x0, x0 + 1),
861 () -> "probability(x + 1) != probability(x, x + 1) for " + x0);
862 }
863
864 final double cdf0 = dist.cumulativeProbability(x0);
865 final double sf0 = cdf0 >= 0.5 ? dist.survivalProbability(x0) : Double.NaN;
866 for (int j = 0; j < points.length; j++) {
867 final int x1 = points[j];
868
869
870 if (x0 + 1L < x1) {
871
872
873
874 double expected;
875 if (cdf0 >= 0.5) {
876 expected = sf0 - dist.survivalProbability(x1);
877 } else {
878 expected = dist.cumulativeProbability(x1) - cdf0;
879 }
880 TestUtils.assertEquals(
881 expected,
882 dist.probability(x0, x1),
883 tolerance,
884 () -> "Inconsistent probability for (" + x0 + "," + x1 + ")");
885 } else if (x0 > x1) {
886 Assertions.assertThrows(IllegalArgumentException.class,
887 () -> dist.probability(x0, x1),
888 "probability(int, int) should have thrown an exception that first argument is too large");
889 }
890 }
891 }
892 }
893
894
895
896
897
898 @ParameterizedTest
899 @MethodSource
900 final void testOutsideSupport(DiscreteDistribution dist,
901 DoubleTolerance tolerance) {
902
903 final int lo = dist.getSupportLowerBound();
904 TestUtils.assertEquals(dist.probability(lo), dist.cumulativeProbability(lo), tolerance, () -> "pmf(lower) != cdf(lower) for " + lo);
905 Assertions.assertEquals(lo, dist.inverseCumulativeProbability(-0.0), "icdf(-0.0)");
906 Assertions.assertEquals(lo, dist.inverseCumulativeProbability(0.0), "icdf(0.0)");
907 Assertions.assertEquals(lo, dist.inverseSurvivalProbability(1.0), "isf(1.0)");
908
909 if (lo != Integer.MIN_VALUE) {
910 final int below = lo - 1;
911 Assertions.assertEquals(0.0, dist.probability(below), "pmf(x < lower)");
912 Assertions.assertEquals(Double.NEGATIVE_INFINITY, dist.logProbability(below), "logpmf(x < lower)");
913 Assertions.assertEquals(0.0, dist.cumulativeProbability(below), "cdf(x < lower)");
914 Assertions.assertEquals(1.0, dist.survivalProbability(below), "sf(x < lower)");
915 }
916
917 final int hi = dist.getSupportUpperBound();
918 Assertions.assertTrue(lo <= hi, "lower <= upper");
919 Assertions.assertEquals(hi, dist.inverseCumulativeProbability(1.0), "icdf(1.0)");
920 Assertions.assertEquals(hi, dist.inverseSurvivalProbability(-0.0), "isf(-0.0)");
921 Assertions.assertEquals(hi, dist.inverseSurvivalProbability(0.0), "isf(0.0)");
922 if (hi != Integer.MAX_VALUE) {
923
924
925 Assertions.assertEquals(1.0, dist.cumulativeProbability(hi), "cdf(upper)");
926 Assertions.assertEquals(0.0, dist.survivalProbability(hi), "sf(upper)");
927 TestUtils.assertEquals(dist.probability(hi), dist.survivalProbability(hi - 1), tolerance, () -> "pmf(upper - 1) != sf(upper - 1) for " + hi);
928
929 final int above = hi + 1;
930 Assertions.assertEquals(0.0, dist.probability(above), "pmf(x > upper)");
931 Assertions.assertEquals(Double.NEGATIVE_INFINITY, dist.logProbability(above), "logpmf(x > upper)");
932 Assertions.assertEquals(1.0, dist.cumulativeProbability(above), "cdf(x > upper)");
933 Assertions.assertEquals(0.0, dist.survivalProbability(above), "sf(x > upper)");
934 }
935
936
937 assertPmfAndLogPmfAtBound(dist, lo, tolerance, "lower");
938 assertPmfAndLogPmfAtBound(dist, hi, tolerance, "upper");
939 }
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954 private static void assertPmfAndLogPmfAtBound(DiscreteDistribution dist, int x,
955 DoubleTolerance tolerance, String name) {
956
957
958 final double p = dist.probability(x);
959 final double logp = dist.logProbability(x);
960 if (p > Double.MIN_NORMAL) {
961 TestUtils.assertEquals(Math.log(p), logp, tolerance,
962 () -> String.format("%s: log(pmf(%d)) != logpmf(%d)", name, x, x));
963 } else {
964 TestUtils.assertEquals(p, Math.exp(logp), tolerance,
965 () -> String.format("%s: pmf(%d) != exp(logpmf(%d))", name, x, x));
966 }
967 }
968
969
970
971
972
973 @ParameterizedTest
974 @MethodSource(value = "streamDistribution")
975 final void testInvalidProbabilities(DiscreteDistribution dist) {
976 final int lo = dist.getSupportLowerBound();
977 final int hi = dist.getSupportUpperBound();
978 if (lo < hi) {
979 Assertions.assertThrows(DistributionException.class, () -> dist.probability(hi, lo), "x0 > x1");
980 }
981 Assertions.assertThrows(DistributionException.class, () -> dist.inverseCumulativeProbability(-1), "p < 0.0");
982 Assertions.assertThrows(DistributionException.class, () -> dist.inverseCumulativeProbability(2), "p > 1.0");
983 Assertions.assertThrows(DistributionException.class, () -> dist.inverseSurvivalProbability(-1), "q < 0.0");
984 Assertions.assertThrows(DistributionException.class, () -> dist.inverseSurvivalProbability(2), "q > 1.0");
985 }
986
987
988
989
990
991
992 @ParameterizedTest
993 @MethodSource
994 final void testSamplingPMF(DiscreteDistribution dist,
995 int[] points,
996 double[] values) {
997
998
999
1000
1001
1002
1003
1004 points = points.clone();
1005 values = values.clone();
1006 final int length = TestUtils.eliminateZeroMassPoints(points, values);
1007 final double[] expected = Arrays.copyOf(values, length);
1008
1009
1010
1011 final double sum = Arrays.stream(expected).sum();
1012 Assumptions.assumeTrue(sum > 0.5,
1013 () -> "Not enough of the PMF is tested during sampling: " + sum);
1014
1015
1016 final DiscreteDistribution.Sampler sampler =
1017 dist.createSampler(RandomSource.XO_SHI_RO_256_PP.create(1234567890L));
1018
1019
1020 if (length == 1) {
1021 final int point = points[0];
1022 for (int i = 0; i < 20; i++) {
1023 Assertions.assertEquals(point, sampler.sample());
1024 }
1025 return;
1026 }
1027
1028 final int sampleSize = 1000;
1029 MathArrays.scaleInPlace(sampleSize, expected);
1030
1031 final int[] sample = TestUtils.sample(sampleSize, sampler);
1032
1033 final long[] counts = new long[length];
1034 for (int i = 0; i < sampleSize; i++) {
1035 final int x = sample[i];
1036 for (int j = 0; j < length; j++) {
1037 if (x == points[j]) {
1038 counts[j]++;
1039 break;
1040 }
1041 }
1042 }
1043
1044 TestUtils.assertChiSquareAccept(points, expected, counts, 0.001);
1045 }
1046
1047
1048
1049
1050
1051
1052
1053
1054 @ParameterizedTest
1055 @MethodSource
1056 final void testSampling(DiscreteDistribution dist) {
1057 final int[] quartiles = TestUtils.getDistributionQuartiles(dist);
1058
1059
1060
1061 final double[] expected = {
1062 dist.cumulativeProbability(quartiles[0]),
1063 dist.probability(quartiles[0], quartiles[1]),
1064 dist.probability(quartiles[1], quartiles[2]),
1065 dist.survivalProbability(quartiles[2]),
1066 };
1067
1068
1069
1070
1071 final DoubleTolerance tolerance = DoubleTolerances.absolute(0.1);
1072 for (final double p : expected) {
1073 Assumptions.assumeTrue(tolerance.test(0.25, p),
1074 () -> "Unexpected quartiles: " + Arrays.toString(expected));
1075 }
1076
1077 final int sampleSize = 1000;
1078 MathArrays.scaleInPlace(sampleSize, expected);
1079
1080
1081 final DiscreteDistribution.Sampler sampler =
1082 dist.createSampler(RandomSource.XO_SHI_RO_256_PP.create(123456789L));
1083 final int[] sample = TestUtils.sample(sampleSize, sampler);
1084
1085 final long[] counts = new long[4];
1086 for (int i = 0; i < sampleSize; i++) {
1087 TestUtils.updateCounts(sample[i], counts, quartiles);
1088 }
1089
1090 TestUtils.assertChiSquareAccept(expected, counts, 0.001);
1091 }
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103 @ParameterizedTest
1104 @MethodSource
1105 final void testProbabilitySums(DiscreteDistribution dist,
1106 int[] points,
1107 double[] values,
1108 DoubleTolerance tolerance) {
1109 final ArrayList<Integer> integrationTestPoints = new ArrayList<>();
1110 for (int i = 0; i < points.length; i++) {
1111 if (Double.isNaN(values[i]) ||
1112 values[i] < 1e-5 ||
1113 values[i] > 1 - 1e-5) {
1114 continue;
1115 }
1116 integrationTestPoints.add(points[i]);
1117 }
1118 Collections.sort(integrationTestPoints);
1119
1120 for (int i = 1; i < integrationTestPoints.size(); i++) {
1121 final int x0 = integrationTestPoints.get(i - 1);
1122 final int x1 = integrationTestPoints.get(i);
1123
1124 if (x1 - x0 > SUM_RANGE_TOO_LARGE) {
1125 continue;
1126 }
1127 final double sum = IntStream.rangeClosed(x0 + 1, x1).mapToDouble(dist::probability).sum();
1128 TestUtils.assertEquals(dist.probability(x0, x1), sum, tolerance,
1129 () -> "Invalid probability sum: " + (x0 + 1) + " to " + x1);
1130 }
1131 }
1132
1133
1134
1135
1136 @ParameterizedTest
1137 @MethodSource
1138 final void testSupport(DiscreteDistribution dist, double lower, double upper) {
1139 Assertions.assertEquals(lower, dist.getSupportLowerBound(), "lower bound");
1140 Assertions.assertEquals(upper, dist.getSupportUpperBound(), "upper bound");
1141 }
1142
1143
1144
1145
1146 @ParameterizedTest
1147 @MethodSource
1148 final void testMoments(DiscreteDistribution dist, double mean, double variance, DoubleTolerance tolerance) {
1149 TestUtils.assertEquals(mean, dist.getMean(), tolerance, "mean");
1150 TestUtils.assertEquals(variance, dist.getVariance(), tolerance, "variance");
1151 }
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161 @ParameterizedTest
1162 @MethodSource
1163 final void testMedian(DiscreteDistribution dist) {
1164 if (dist instanceof AbstractDiscreteDistribution) {
1165 final AbstractDiscreteDistribution d = (AbstractDiscreteDistribution) dist;
1166 Assertions.assertEquals(d.inverseCumulativeProbability(0.5), d.getMedian(), "median");
1167 assertMethodNotModified(dist.getClass(), Modifier.PUBLIC | Modifier.PROTECTED, "getMedian");
1168 }
1169 }
1170 }