1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.statistics.distribution;
18
19 import java.math.BigDecimal;
20 import java.math.BigInteger;
21 import org.junit.jupiter.api.Assertions;
22 import org.junit.jupiter.api.Test;
23 import org.junit.jupiter.params.ParameterizedTest;
24 import org.junit.jupiter.params.provider.CsvSource;
25 import org.junit.jupiter.params.provider.ValueSource;
26
27
28
29
30
31 class BinomialDistributionTest extends BaseDiscreteDistributionTest {
32 @Override
33 DiscreteDistribution makeDistribution(Object... parameters) {
34 final int n = (Integer) parameters[0];
35 final double p = (Double) parameters[1];
36 return BinomialDistribution.of(n, p);
37 }
38
39
40 @Override
41 Object[][] makeInvalidParameters() {
42 return new Object[][] {
43 {-1, 0.1},
44 {10, -0.1},
45 {10, 1.1},
46 };
47 }
48
49 @Override
50 String[] getParameterNames() {
51 return new String[] {"NumberOfTrials", "ProbabilityOfSuccess"};
52 }
53
54 @Override
55 protected double getRelativeTolerance() {
56
57 return 4 * RELATIVE_EPS;
58 }
59
60
61
62 @Test
63 void testMath718() {
64
65
66
67
68 for (int trials = 500000; trials < 20000000; trials += 100000) {
69 final BinomialDistribution dist = BinomialDistribution.of(trials, 0.5);
70 final int p = dist.inverseCumulativeProbability(0.5);
71 Assertions.assertEquals(trials / 2, p);
72 }
73 }
74
75
76
77
78 @ParameterizedTest
79 @ValueSource(ints = {0, 1, 2, 3, 10})
80 void testProbabilityOfSuccess0(int n) {
81
82
83 for (final double p : new double[] {-0.0, 0.0}) {
84 final BinomialDistribution dist = BinomialDistribution.of(n, p);
85 for (int k = -1; k <= n + 1; k++) {
86 Assertions.assertEquals(k == 0 ? 1.0 : 0.0, dist.probability(k));
87 Assertions.assertEquals(k == 0 ? 0.0 : Double.NEGATIVE_INFINITY, dist.logProbability(k));
88 Assertions.assertEquals(k >= 0 ? 1.0 : 0.0, dist.cumulativeProbability(k));
89 Assertions.assertEquals(k >= 0 ? 0.0 : 1.0, dist.survivalProbability(k));
90 }
91 }
92 }
93
94
95
96
97 @ParameterizedTest
98 @ValueSource(ints = {0, 1, 2, 3, 10})
99 void testProbabilityOfSuccess1(int n) {
100 final BinomialDistribution dist = BinomialDistribution.of(n, 1);
101
102 for (int k = -1; k <= n + 1; k++) {
103 Assertions.assertEquals(k == n ? 1.0 : 0.0, dist.probability(k));
104 Assertions.assertEquals(k == n ? 0.0 : Double.NEGATIVE_INFINITY, dist.logProbability(k));
105 Assertions.assertEquals(k >= n ? 1.0 : 0.0, dist.cumulativeProbability(k));
106 Assertions.assertEquals(k >= n ? 0.0 : 1.0, dist.survivalProbability(k));
107 }
108 }
109
110 @ParameterizedTest
111 @ValueSource(doubles = {0, 1, 0.01, 0.99, 1e-17, 0.3645257e-8, 0.123415276368128, 0.67834532657232434})
112 void testNumberOfTrials0(double p) {
113 final BinomialDistribution dist = BinomialDistribution.of(0, p);
114
115 for (int k = -1; k <= 2; k++) {
116 Assertions.assertEquals(k == 0 ? 1.0 : 0.0, dist.probability(k));
117 Assertions.assertEquals(k == 0 ? 0.0 : Double.NEGATIVE_INFINITY, dist.logProbability(k));
118 Assertions.assertEquals(k >= 0 ? 1.0 : 0.0, dist.cumulativeProbability(k));
119 Assertions.assertEquals(k >= 0 ? 0.0 : 1.0, dist.survivalProbability(k));
120 }
121 }
122
123 @ParameterizedTest
124 @ValueSource(doubles = {0, 1, 0.01, 0.99, 1e-17, 0.3645257e-8, 0.123415276368128, 0.67834532657232434})
125 void testNumberOfTrials1(double p) {
126 final BinomialDistribution dist = BinomialDistribution.of(1, p);
127
128 Assertions.assertEquals(0.0, dist.probability(-1));
129 Assertions.assertEquals(1 - p, dist.probability(0));
130 Assertions.assertEquals(p, dist.probability(1));
131 Assertions.assertEquals(0.0, dist.probability(2));
132 Assertions.assertEquals(Double.NEGATIVE_INFINITY, dist.logProbability(-1));
133
134 TestUtils.assertEquals(Math.log1p(-p), dist.logProbability(0), DoubleTolerances.ulps(1));
135 Assertions.assertEquals(Math.log(p), dist.logProbability(1));
136 Assertions.assertEquals(Double.NEGATIVE_INFINITY, dist.logProbability(2));
137 Assertions.assertEquals(0.0, dist.cumulativeProbability(-1));
138 Assertions.assertEquals(1 - p, dist.cumulativeProbability(0));
139 Assertions.assertEquals(1.0, dist.cumulativeProbability(1));
140 Assertions.assertEquals(1.0, dist.cumulativeProbability(2));
141 Assertions.assertEquals(1.0, dist.survivalProbability(-1));
142 Assertions.assertEquals(p, dist.survivalProbability(0));
143 Assertions.assertEquals(0.0, dist.survivalProbability(1));
144 Assertions.assertEquals(0.0, dist.survivalProbability(2));
145 }
146
147
148
149
150
151
152
153
154
155 @ParameterizedTest
156 @ValueSource(doubles = {0, 1, 0.01, 0.99, 1e-17, 0.3645257e-8, 0.123415276368128, 0.67834532657232434})
157 void testX0(double p) {
158 for (final int n : new int[] {0, 1, 10}) {
159 final BinomialDistribution dist = BinomialDistribution.of(n, p);
160 final double expected = SaddlePointExpansionUtils.logBinomialProbability(0, n, p, 1 - p);
161 Assertions.assertEquals(expected, dist.logProbability(0));
162 }
163 }
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180 @ParameterizedTest
181 @CsvSource({
182
183 "100, 0.50, 8e-15, 8e-15",
184 "100, 0.01, 1e-15, 3e-14",
185 "100, 0.99, 4e-15, 1e-15",
186 "140, 0.01, 1e-15, 2e-13",
187 "140, 0.99, 2e-13, 1e-15",
188 "50, 0.001, 1e-15, 5e-14",
189 "50, 0.999, 3e-14, 1e-15",
190 })
191 void testBounds(int n, double p, double eps1, double epsn1) {
192 final BinomialDistribution dist = BinomialDistribution.of(n, p);
193 final BigDecimal prob0 = binomialProbability(n, p, 0);
194 final BigDecimal probn = binomialProbability(n, p, n);
195 final double p0 = prob0.doubleValue();
196 final double pn = probn.doubleValue();
197
198
199
200 final double minp = Math.min(p0, pn);
201 Assertions.assertTrue(minp < 0x1.0p-53, () -> "Test should target small p-values: " + minp);
202 Assertions.assertTrue(minp > Double.MIN_NORMAL, () -> "Minimum P-value should not be sub-normal: " + minp);
203
204
205 final DoubleTolerance tol1 = DoubleTolerances.ulps(1);
206 TestUtils.assertEquals(p0, dist.probability(0), tol1, "pmf(0)");
207 TestUtils.assertEquals(pn, dist.probability(n), tol1, "pmf(n)");
208
209 Assertions.assertEquals(dist.probability(0), dist.cumulativeProbability(0), "pmf(0) != cdf(0)");
210 Assertions.assertEquals(dist.probability(n), dist.survivalProbability(n - 1), "pmf(n) != sf(n-1)");
211
212
213
214 if (p0 < 0.9) {
215 TestUtils.assertEquals(Math.log(p0), dist.logProbability(0), tol1, "log(pmf(0)) != logpmf(0)");
216 } else {
217 TestUtils.assertEquals(p0, Math.exp(dist.logProbability(0)), tol1, "pmf(0) != exp(logpmf(0))");
218 }
219 if (pn < 0.9) {
220 TestUtils.assertEquals(Math.log(pn), dist.logProbability(n), tol1, "log(pmf(n)) != logpmf(n)");
221 } else {
222 TestUtils.assertEquals(pn, Math.exp(dist.logProbability(n)), tol1, "pmf(n) != exp(logpmf(n))");
223 }
224
225
226 final BigDecimal prob1 = binomialProbability(n, p, 1);
227 final BigDecimal probn1 = binomialProbability(n, p, n - 1);
228 TestUtils.assertEquals(prob1.doubleValue(), dist.probability(1), createRelTolerance(eps1), "pmf(1)");
229 TestUtils.assertEquals(probn1.doubleValue(), dist.probability(n - 1), createRelTolerance(epsn1), "pmf(n-1)");
230
231
232 final double cdf1 = prob0.add(prob1).doubleValue();
233 final double sfn2 = probn.add(probn1).doubleValue();
234 TestUtils.assertEquals(cdf1, dist.cumulativeProbability(1), createRelTolerance(eps1), "cmf(1)");
235 TestUtils.assertEquals(sfn2, dist.survivalProbability(n - 2), createRelTolerance(epsn1), "sf(n-2)");
236 }
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255 private static BigDecimal binomialProbability(int n, double p, int k) {
256 final int nmk = n - k;
257 final int m = Math.min(k, nmk);
258
259 final BigDecimal bp = new BigDecimal(p);
260 final BigDecimal result = bp.pow(k).multiply(
261 BigDecimal.ONE.subtract(bp).pow(nmk));
262
263
264 if (m == 0) {
265 return result;
266 } else if (m == 1) {
267 return result.multiply(new BigDecimal(n));
268 }
269
270 BigInteger nCk = BigInteger.ONE;
271 int i = n - m + 1;
272 for (int j = 1; j <= m; j++) {
273 nCk = nCk.multiply(BigInteger.valueOf(i)).divide(BigInteger.valueOf(j));
274 i++;
275 }
276 return new BigDecimal(nCk).multiply(result);
277 }
278 }