View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.statistics.distribution;
18  
19  import java.math.BigDecimal;
20  import java.math.BigInteger;
21  import org.junit.jupiter.api.Assertions;
22  import org.junit.jupiter.api.Test;
23  import org.junit.jupiter.params.ParameterizedTest;
24  import org.junit.jupiter.params.provider.CsvSource;
25  import org.junit.jupiter.params.provider.ValueSource;
26  
27  /**
28   * Test cases for {@link BinomialDistribution}.
29   * Extends {@link BaseDiscreteDistributionTest}. See javadoc of that class for details.
30   */
31  class BinomialDistributionTest extends BaseDiscreteDistributionTest {
32      @Override
33      DiscreteDistribution makeDistribution(Object... parameters) {
34          final int n = (Integer) parameters[0];
35          final double p = (Double) parameters[1];
36          return BinomialDistribution.of(n, p);
37      }
38  
39  
40      @Override
41      Object[][] makeInvalidParameters() {
42          return new Object[][] {
43              {-1, 0.1},
44              {10, -0.1},
45              {10, 1.1},
46          };
47      }
48  
49      @Override
50      String[] getParameterNames() {
51          return new String[] {"NumberOfTrials", "ProbabilityOfSuccess"};
52      }
53  
54      @Override
55      protected double getRelativeTolerance() {
56          // Tolerance is 8.881784197001252E-16
57          return 4 * RELATIVE_EPS;
58      }
59  
60      //-------------------- Additional test cases -------------------------------
61  
62      @Test
63      void testMath718() {
64          // For large trials the evaluation of ContinuedFraction was inaccurate.
65          // Do a sweep over several large trials to test if the current implementation is
66          // numerically stable.
67  
68          for (int trials = 500000; trials < 20000000; trials += 100000) {
69              final BinomialDistribution dist = BinomialDistribution.of(trials, 0.5);
70              final int p = dist.inverseCumulativeProbability(0.5);
71              Assertions.assertEquals(trials / 2, p);
72          }
73      }
74  
75      /**
76       * Test special case of probability of success 0.0.
77       */
78      @ParameterizedTest
79      @ValueSource(ints = {0, 1, 2, 3, 10})
80      void testProbabilityOfSuccess0(int n) {
81          // The sign of p should not matter.
82          // Exact equality checks no -0.0 values are generated.
83          for (final double p : new double[] {-0.0, 0.0}) {
84              final BinomialDistribution dist = BinomialDistribution.of(n, p);
85              for (int k = -1; k <= n + 1; k++) {
86                  Assertions.assertEquals(k == 0 ? 1.0 : 0.0, dist.probability(k));
87                  Assertions.assertEquals(k == 0 ? 0.0 : Double.NEGATIVE_INFINITY, dist.logProbability(k));
88                  Assertions.assertEquals(k >= 0 ? 1.0 : 0.0, dist.cumulativeProbability(k));
89                  Assertions.assertEquals(k >= 0 ? 0.0 : 1.0, dist.survivalProbability(k));
90              }
91          }
92      }
93  
94      /**
95       * Test special case of probability of success 1.0.
96       */
97      @ParameterizedTest
98      @ValueSource(ints = {0, 1, 2, 3, 10})
99      void testProbabilityOfSuccess1(int n) {
100         final BinomialDistribution dist = BinomialDistribution.of(n, 1);
101         // Exact equality checks no -0.0 values are generated.
102         for (int k = -1; k <= n + 1; k++) {
103             Assertions.assertEquals(k == n ? 1.0 : 0.0, dist.probability(k));
104             Assertions.assertEquals(k == n ? 0.0 : Double.NEGATIVE_INFINITY, dist.logProbability(k));
105             Assertions.assertEquals(k >= n ? 1.0 : 0.0, dist.cumulativeProbability(k));
106             Assertions.assertEquals(k >= n ? 0.0 : 1.0, dist.survivalProbability(k));
107         }
108     }
109 
110     @ParameterizedTest
111     @ValueSource(doubles = {0, 1, 0.01, 0.99, 1e-17, 0.3645257e-8, 0.123415276368128, 0.67834532657232434})
112     void testNumberOfTrials0(double p) {
113         final BinomialDistribution dist = BinomialDistribution.of(0, p);
114         // Edge case where the probability is ignored when computing the result
115         for (int k = -1; k <= 2; k++) {
116             Assertions.assertEquals(k == 0 ? 1.0 : 0.0, dist.probability(k));
117             Assertions.assertEquals(k == 0 ? 0.0 : Double.NEGATIVE_INFINITY, dist.logProbability(k));
118             Assertions.assertEquals(k >= 0 ? 1.0 : 0.0, dist.cumulativeProbability(k));
119             Assertions.assertEquals(k >= 0 ? 0.0 : 1.0, dist.survivalProbability(k));
120         }
121     }
122 
123     @ParameterizedTest
124     @ValueSource(doubles = {0, 1, 0.01, 0.99, 1e-17, 0.3645257e-8, 0.123415276368128, 0.67834532657232434})
125     void testNumberOfTrials1(double p) {
126         final BinomialDistribution dist = BinomialDistribution.of(1, p);
127         // Edge case where the probability should be exact
128         Assertions.assertEquals(0.0, dist.probability(-1));
129         Assertions.assertEquals(1 - p, dist.probability(0));
130         Assertions.assertEquals(p, dist.probability(1));
131         Assertions.assertEquals(0.0, dist.probability(2));
132         Assertions.assertEquals(Double.NEGATIVE_INFINITY, dist.logProbability(-1));
133         // Current implementation does not use log1p so allow an error tolerance
134         TestUtils.assertEquals(Math.log1p(-p), dist.logProbability(0), DoubleTolerances.ulps(1));
135         Assertions.assertEquals(Math.log(p), dist.logProbability(1));
136         Assertions.assertEquals(Double.NEGATIVE_INFINITY, dist.logProbability(2));
137         Assertions.assertEquals(0.0, dist.cumulativeProbability(-1));
138         Assertions.assertEquals(1 - p, dist.cumulativeProbability(0));
139         Assertions.assertEquals(1.0, dist.cumulativeProbability(1));
140         Assertions.assertEquals(1.0, dist.cumulativeProbability(2));
141         Assertions.assertEquals(1.0, dist.survivalProbability(-1));
142         Assertions.assertEquals(p, dist.survivalProbability(0));
143         Assertions.assertEquals(0.0, dist.survivalProbability(1));
144         Assertions.assertEquals(0.0, dist.survivalProbability(2));
145     }
146 
147     /**
148      * Special case for x=0.
149      * This hits cases where the SaddlePointExpansionUtils are not used for
150      * probability functions. It ensures the edge case handling in BinomialDistribution
151      * matches the original logic in the saddle point expansion. This x=0 logic is used
152      * by the related hypergeometric distribution and covered by test cases for that
153      * distribution ensuring it is correct.
154      */
155     @ParameterizedTest
156     @ValueSource(doubles = {0, 1, 0.01, 0.99, 1e-17, 0.3645257e-8, 0.123415276368128, 0.67834532657232434})
157     void testX0(double p) {
158         for (final int n : new int[] {0, 1, 10}) {
159             final BinomialDistribution dist = BinomialDistribution.of(n, p);
160             final double expected = SaddlePointExpansionUtils.logBinomialProbability(0, n, p, 1 - p);
161             Assertions.assertEquals(expected, dist.logProbability(0));
162         }
163     }
164 
165     /**
166      * Test the probability functions at the lower and upper bounds when the p-values
167      * are very small. The expected results should be within 1 ULP of an exact
168      * computation for pmf(x=0) and pmf(x=n). These values can be computed using
169      * java.lang.Math functions.
170      *
171      * <p>The next value, e.g. pmf(x=1) and cdf(x=1), is asserted to the specified
172      * relative error tolerance. These values require computations using p and 1-p
173      * and are less exact.
174      *
175      * @param n Number of trials
176      * @param p Probability of success
177      * @param eps1 Relative error tolerance for pmf(x=1)
178      * @param epsn1 Relative error tolerance for pmf(x=n-1)
179      */
180     @ParameterizedTest
181     @CsvSource({
182         // Min p-value is shown for reference.
183         "100, 0.50, 8e-15, 8e-15", // 7.888609052210118E-31
184         "100, 0.01, 1e-15, 3e-14", // 1.000000000000002E-200
185         "100, 0.99, 4e-15, 1e-15", // 1.0000000000000887E-200
186         "140, 0.01, 1e-15, 2e-13", // 1.0000000000000029E-280
187         "140, 0.99, 2e-13, 1e-15", // 1.0000000000001244E-280
188         "50, 0.001, 1e-15, 5e-14", // 1.0000000000000011E-150
189         "50, 0.999, 3e-14, 1e-15", // 1.0000000000000444E-150
190     })
191     void testBounds(int n, double p, double eps1, double epsn1) {
192         final BinomialDistribution dist = BinomialDistribution.of(n, p);
193         final BigDecimal prob0 = binomialProbability(n, p, 0);
194         final BigDecimal probn = binomialProbability(n, p, n);
195         final double p0 = prob0.doubleValue();
196         final double pn = probn.doubleValue();
197 
198         // Require very small non-zero probabilities to make the test difficult.
199         // Check using 2^-53 so that at least one p-value satisfies 1 - p == 1.
200         final double minp = Math.min(p0, pn);
201         Assertions.assertTrue(minp < 0x1.0p-53, () -> "Test should target small p-values: " + minp);
202         Assertions.assertTrue(minp > Double.MIN_NORMAL, () -> "Minimum P-value should not be sub-normal: " + minp);
203 
204         // Almost exact at the bounds
205         final DoubleTolerance tol1 = DoubleTolerances.ulps(1);
206         TestUtils.assertEquals(p0, dist.probability(0), tol1, "pmf(0)");
207         TestUtils.assertEquals(pn, dist.probability(n), tol1, "pmf(n)");
208         // Consistent at the bounds
209         Assertions.assertEquals(dist.probability(0), dist.cumulativeProbability(0), "pmf(0) != cdf(0)");
210         Assertions.assertEquals(dist.probability(n), dist.survivalProbability(n - 1), "pmf(n) != sf(n-1)");
211 
212         // Test probability and log probability are consistent.
213         // Avoid log when p-value is close to 1.
214         if (p0 < 0.9) {
215             TestUtils.assertEquals(Math.log(p0), dist.logProbability(0), tol1, "log(pmf(0)) != logpmf(0)");
216         } else {
217             TestUtils.assertEquals(p0, Math.exp(dist.logProbability(0)), tol1, "pmf(0) != exp(logpmf(0))");
218         }
219         if (pn < 0.9) {
220             TestUtils.assertEquals(Math.log(pn), dist.logProbability(n), tol1, "log(pmf(n)) != logpmf(n)");
221         } else {
222             TestUtils.assertEquals(pn, Math.exp(dist.logProbability(n)), tol1, "pmf(n) != exp(logpmf(n))");
223         }
224 
225         // The next probability is accurate to the specified tolerance.
226         final BigDecimal prob1 = binomialProbability(n, p, 1);
227         final BigDecimal probn1 = binomialProbability(n, p, n - 1);
228         TestUtils.assertEquals(prob1.doubleValue(), dist.probability(1), createRelTolerance(eps1), "pmf(1)");
229         TestUtils.assertEquals(probn1.doubleValue(), dist.probability(n - 1), createRelTolerance(epsn1), "pmf(n-1)");
230 
231         // Check the cumulative functions
232         final double cdf1 = prob0.add(prob1).doubleValue();
233         final double sfn2 = probn.add(probn1).doubleValue();
234         TestUtils.assertEquals(cdf1, dist.cumulativeProbability(1), createRelTolerance(eps1), "cmf(1)");
235         TestUtils.assertEquals(sfn2, dist.survivalProbability(n - 2), createRelTolerance(epsn1), "sf(n-2)");
236     }
237 
238     /**
239      * Compute the binomial distribution probability mass function using exact
240      * arithmetic.
241      *
242      * <p>This has no error handling for invalid arguments.
243      *
244      * <p>Warning: BigDecimal has a limit on the size of the exponent for the power
245      * function. This method has not been extensively tested with very small
246      * p-values, large n or large k. Use of a MathContext to round intermediates may be
247      * required to reduce memory consumption. The binomial coefficient may not
248      * compute for large n and k ~ n/2.
249      *
250      * @param n Number of trials (must be positive)
251      * @param p Probability of success (in [0, 1])
252      * @param k Number of successes (must be positive)
253      * @return pmf(X=k)
254      */
255     private static BigDecimal binomialProbability(int n, double p, int k) {
256         final int nmk = n - k;
257         final int m = Math.min(k, nmk);
258         // Probability component: p^k * (1-p)^(n-k)
259         final BigDecimal bp = new BigDecimal(p);
260         final BigDecimal result = bp.pow(k).multiply(
261                  BigDecimal.ONE.subtract(bp).pow(nmk));
262         // Compute the binomial coefficient
263         // Simple edge cases first.
264         if (m == 0) {
265             return result;
266         } else if (m == 1) {
267             return result.multiply(new BigDecimal(n));
268         }
269         // See org.apache.commons.numbers.combinatorics.BinomialCoefficient
270         BigInteger nCk = BigInteger.ONE;
271         int i = n - m + 1;
272         for (int j = 1; j <= m; j++) {
273             nCk = nCk.multiply(BigInteger.valueOf(i)).divide(BigInteger.valueOf(j));
274             i++;
275         }
276         return new BigDecimal(nCk).multiply(result);
277     }
278 }