1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.statistics.distribution;
18
19 import java.util.concurrent.ThreadLocalRandom;
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.junit.jupiter.api.Assertions;
22 import org.junit.jupiter.api.Test;
23 import org.junit.jupiter.params.ParameterizedTest;
24 import org.junit.jupiter.params.provider.ValueSource;
25
26
27
28
29 class ContinuousDistributionTest {
30
31
32
33 @Test
34 void testDefaultMethods() {
35 final double high = 0.54;
36 final double low = 0.313;
37
38 final ContinuousDistribution dist = new ContinuousDistribution() {
39 @Override
40 public double inverseCumulativeProbability(double p) {
41
42 return 10 * p;
43 }
44 @Override
45 public double getVariance() {
46 return 0;
47 }
48 @Override
49 public double getSupportUpperBound() {
50 return 0;
51 }
52 @Override
53 public double getSupportLowerBound() {
54 return 0;
55 }
56 @Override
57 public double getMean() {
58 return 0;
59 }
60 @Override
61 public double density(double x) {
62
63 return x;
64 }
65 @Override
66 public double cumulativeProbability(double x) {
67
68 return x > 1 ? high : low;
69 }
70 @Override
71 public Sampler createSampler(UniformRandomProvider rng) {
72 return null;
73 }
74 };
75
76 for (final double x : new double[] {Double.NaN, Double.POSITIVE_INFINITY,
77 Double.NEGATIVE_INFINITY, 0, 1, 0.123}) {
78
79 Assertions.assertEquals(Math.log(x), dist.logDensity(x));
80 }
81
82
83 Assertions.assertThrows(DistributionException.class, () -> dist.probability(0.5, 0.4));
84 Assertions.assertEquals(high - low, dist.probability(0.5, 1.5));
85 Assertions.assertEquals(high - low, dist.probability(0.5, 1.5));
86 for (final double p : new double[] {0.2, 0.5, 0.7}) {
87 Assertions.assertEquals(dist.inverseCumulativeProbability(1 - p),
88 dist.inverseSurvivalProbability(p));
89 }
90 }
91
92
93
94
95
96
97 @ParameterizedTest
98 @ValueSource(longs = {0, 1, 13})
99 void testSamplerStreamMethods(long streamSize) {
100 final double seed = ThreadLocalRandom.current().nextDouble();
101 final ContinuousDistribution.Sampler s1 = createIncrementSampler(seed);
102 final ContinuousDistribution.Sampler s2 = createIncrementSampler(seed);
103 final ContinuousDistribution.Sampler s3 = createIncrementSampler(seed);
104
105 final double[] x = new double[(int) streamSize];
106 for (int i = 0; i < x.length; i++) {
107 x[i] = s1.sample();
108 }
109
110 Assertions.assertArrayEquals(x, s2.samples().limit(streamSize).toArray(), "samples()");
111 Assertions.assertArrayEquals(x, s3.samples(streamSize).toArray(), "samples(long)");
112 }
113
114
115
116
117
118
119 @ParameterizedTest
120 @ValueSource(longs = {-1, -6576237846822L})
121 void testSamplerStreamMethodsThrow(long streamSize) {
122 final ContinuousDistribution.Sampler s = createIncrementSampler(42);
123 Assertions.assertThrows(IllegalArgumentException.class, () -> s.samples(streamSize));
124 }
125
126
127
128
129 @Test
130 void testSamplerStreamMethodsNotParallel() {
131 final ContinuousDistribution.Sampler s = createIncrementSampler(42);
132 Assertions.assertFalse(s.samples().isParallel(), "samples() should not be parallel");
133 Assertions.assertFalse(s.samples(11).isParallel(), "samples(long) should not be parallel");
134 }
135
136
137
138
139
140
141
142
143 private static ContinuousDistribution.Sampler createIncrementSampler(double seed) {
144 return new ContinuousDistribution.Sampler() {
145 private double x = seed;
146
147 @Override
148 public double sample() {
149 return x += 1;
150 }
151 };
152 }
153 }