View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import java.util.ArrayList;
20  import java.util.Collections;
21  import java.util.List;
22  import org.apache.commons.math3.util.MathArrays;
23  import org.apache.commons.rng.UniformRandomProvider;
24  import org.apache.commons.rng.sampling.RandomAssert;
25  
26  /**
27   * List of samplers.
28   */
29  public final class DiscreteSamplersList {
30      /** List of all RNGs implemented in the library. */
31      private static final List<DiscreteSamplerTestData> LIST = new ArrayList<>();
32  
33      static {
34          try {
35              // This test uses reference distributions from commons-math3 to compute the expected
36              // PMF. These distributions have a dual functionality to compute the PMF and perform
37              // sampling. When no sampling is needed for the created distribution, it is advised
38              // to pass null as the random generator via the appropriate constructors to avoid the
39              // additional initialisation overhead.
40              org.apache.commons.math3.random.RandomGenerator unusedRng = null;
41  
42              // List of distributions to test.
43  
44              // Binomial ("inverse method").
45              final int trialsBinomial = 20;
46              final double probSuccessBinomial = 0.67;
47              add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
48                  MathArrays.sequence(8, 9, 1),
49                  RandomAssert.createRNG());
50              add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
51                  // range [9,16]
52                  MathArrays.sequence(8, 9, 1),
53                  MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomAssert.createRNG(), trialsBinomial, probSuccessBinomial));
54              // Inverted
55              add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, 1 - probSuccessBinomial),
56                  // range [4,11] = [20-16, 20-9]
57                  MathArrays.sequence(8, 4, 1),
58                  MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomAssert.createRNG(), trialsBinomial, 1 - probSuccessBinomial));
59  
60              // Geometric ("inverse method").
61              final double probSuccessGeometric = 0.21;
62              add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
63                  MathArrays.sequence(10, 0, 1),
64                  RandomAssert.createRNG());
65              // Geometric.
66              add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
67                  MathArrays.sequence(10, 0, 1),
68                  GeometricSampler.of(RandomAssert.createRNG(), probSuccessGeometric));
69  
70              // Hypergeometric ("inverse method").
71              final int popSizeHyper = 34;
72              final int numSuccessesHyper = 11;
73              final int sampleSizeHyper = 12;
74              add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(unusedRng, popSizeHyper, numSuccessesHyper, sampleSizeHyper),
75                  MathArrays.sequence(10, 0, 1),
76                  RandomAssert.createRNG());
77  
78              // Pascal ("inverse method").
79              final int numSuccessesPascal = 6;
80              final double probSuccessPascal = 0.2;
81              add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(unusedRng, numSuccessesPascal, probSuccessPascal),
82                  MathArrays.sequence(18, 1, 1),
83                  RandomAssert.createRNG());
84  
85              // Uniform ("inverse method").
86              final int loUniform = -3;
87              final int hiUniform = 4;
88              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
89                  MathArrays.sequence(8, -3, 1),
90                  RandomAssert.createRNG());
91              // Uniform (power of 2 range).
92              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
93                  MathArrays.sequence(8, -3, 1),
94                  DiscreteUniformSampler.of(RandomAssert.createRNG(), loUniform, hiUniform));
95              // Uniform (large range).
96              final int halfMax = Integer.MAX_VALUE / 2;
97              final int hiLargeUniform = halfMax + 10;
98              final int loLargeUniform = -hiLargeUniform;
99              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loLargeUniform, hiLargeUniform),
100                 MathArrays.sequence(20, -halfMax, halfMax / 10),
101                 DiscreteUniformSampler.of(RandomAssert.createRNG(), loLargeUniform, hiLargeUniform));
102             // Uniform (non-power of 2 range).
103             final int rangeNonPowerOf2Uniform = 11;
104             final int hiNonPowerOf2Uniform = loUniform + rangeNonPowerOf2Uniform;
105             add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiNonPowerOf2Uniform),
106                 MathArrays.sequence(rangeNonPowerOf2Uniform, -3, 1),
107                 DiscreteUniformSampler.of(RandomAssert.createRNG(), loUniform, hiNonPowerOf2Uniform));
108 
109             // Zipf ("inverse method").
110             final int numElementsZipf = 5;
111             final double exponentZipf = 2.345;
112             add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
113                 MathArrays.sequence(5, 1, 1),
114                 RandomAssert.createRNG());
115             // Zipf.
116             add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
117                 MathArrays.sequence(5, 1, 1),
118                 RejectionInversionZipfSampler.of(RandomAssert.createRNG(), numElementsZipf, exponentZipf));
119             // Zipf (exponent close to 1).
120             final double exponentCloseToOneZipf = 1 - 1e-10;
121             add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentCloseToOneZipf),
122                 MathArrays.sequence(5, 1, 1),
123                 RejectionInversionZipfSampler.of(RandomAssert.createRNG(), numElementsZipf, exponentCloseToOneZipf));
124             // Zipf (exponent = 0).
125             add(LIST, MathArrays.sequence(5, 1, 1), new double[] {0.2, 0.2, 0.2, 0.2, 0.2},
126                 RejectionInversionZipfSampler.of(RandomAssert.createRNG(), numElementsZipf, 0.0));
127 
128             // Poisson ("inverse method").
129             final double epsilonPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_EPSILON;
130             final int maxIterationsPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_MAX_ITERATIONS;
131             final double meanPoisson = 3.21;
132             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
133                 MathArrays.sequence(10, 0, 1),
134                 RandomAssert.createRNG());
135             // Poisson.
136             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
137                 MathArrays.sequence(10, 0, 1),
138                 PoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
139             // Dedicated small mean poisson samplers
140             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
141                 MathArrays.sequence(10, 0, 1),
142                 SmallMeanPoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
143             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
144                 MathArrays.sequence(10, 0, 1),
145                 KempSmallMeanPoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
146             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
147                 MathArrays.sequence(10, 0, 1),
148                 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomAssert.createRNG(), meanPoisson));
149             // LargeMeanPoissonSampler should work at small mean.
150             // Note: This hits a code path where the sample from the normal distribution is rejected.
151             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
152                 MathArrays.sequence(10, 0, 1),
153                 LargeMeanPoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
154             // Poisson (40 < mean < 80).
155             final double largeMeanPoisson = 67.89;
156             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
157                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
158                 PoissonSampler.of(RandomAssert.createRNG(), largeMeanPoisson));
159             // Dedicated large mean poisson sampler
160             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
161                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
162                 LargeMeanPoissonSampler.of(RandomAssert.createRNG(), largeMeanPoisson));
163             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
164                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
165                 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomAssert.createRNG(), largeMeanPoisson));
166             // Poisson (mean >> 40).
167             final double veryLargeMeanPoisson = 543.21;
168             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
169                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
170                 PoissonSampler.of(RandomAssert.createRNG(), veryLargeMeanPoisson));
171             // Dedicated large mean poisson sampler
172             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
173                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
174                 LargeMeanPoissonSampler.of(RandomAssert.createRNG(), veryLargeMeanPoisson));
175             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
176                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
177                 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomAssert.createRNG(), veryLargeMeanPoisson));
178 
179             // Any discrete distribution
180             final int[] discretePoints = {0, 1, 2, 3, 4};
181             final double[] discreteProbabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
182             final long[] discreteFrequencies = {1, 2, 3, 4, 5};
183             add(LIST, discretePoints, discreteProbabilities,
184                 MarsagliaTsangWangDiscreteSampler.Enumerated.of(RandomAssert.createRNG(), discreteProbabilities));
185             add(LIST, discretePoints, discreteProbabilities,
186                 GuideTableDiscreteSampler.of(RandomAssert.createRNG(), discreteProbabilities));
187             add(LIST, discretePoints, discreteProbabilities,
188                 AliasMethodDiscreteSampler.of(RandomAssert.createRNG(), discreteProbabilities));
189             add(LIST, discretePoints, discreteProbabilities,
190                 FastLoadedDiceRollerDiscreteSampler.of(RandomAssert.createRNG(), discreteFrequencies));
191             add(LIST, discretePoints, discreteProbabilities,
192                 FastLoadedDiceRollerDiscreteSampler.of(RandomAssert.createRNG(), discreteProbabilities));
193         } catch (Exception e) {
194             // CHECKSTYLE: stop Regexp
195             System.err.println("Unexpected exception while creating the list of samplers: " + e);
196             e.printStackTrace(System.err);
197             // CHECKSTYLE: resume Regexp
198             throw new RuntimeException(e);
199         }
200     }
201 
202     /**
203      * Class contains only static methods.
204      */
205     private DiscreteSamplersList() {}
206 
207     /**
208      * @param list List of data (one the "parameters" tested by the Junit parametric test).
209      * @param dist Distribution to which the samples are supposed to conform.
210      * @param points Outcomes selection.
211      * @param rng Generator of uniformly distributed sequences.
212      */
213     private static void add(List<DiscreteSamplerTestData> list,
214                             final org.apache.commons.math3.distribution.IntegerDistribution dist,
215                             int[] points,
216                             UniformRandomProvider rng) {
217         final DiscreteSampler inverseMethodSampler =
218             InverseTransformDiscreteSampler.of(rng,
219                 new DiscreteInverseCumulativeProbabilityFunction() {
220                     @Override
221                     public int inverseCumulativeProbability(double p) {
222                         return dist.inverseCumulativeProbability(p);
223                     }
224                     @Override
225                     public String toString() {
226                         return dist.toString();
227                     }
228                 });
229         list.add(new DiscreteSamplerTestData(inverseMethodSampler,
230                                              points,
231                                              getProbabilities(dist, points)));
232     }
233 
234     /**
235      * @param list List of data (one the "parameters" tested by the Junit parametric test).
236      * @param dist Distribution to which the samples are supposed to conform.
237      * @param points Outcomes selection.
238      * @param sampler Sampler.
239      */
240     private static void add(List<DiscreteSamplerTestData> list,
241                             final org.apache.commons.math3.distribution.IntegerDistribution dist,
242                             int[] points,
243                             final DiscreteSampler sampler) {
244         list.add(new DiscreteSamplerTestData(sampler,
245                                              points,
246                                              getProbabilities(dist, points)));
247     }
248 
249     /**
250      * @param list List of data (one the "parameters" tested by the Junit parametric test).
251      * @param points Outcomes selection.
252      * @param probabilities Probability distribution to which the samples are supposed to conform.
253      * @param sampler Sampler.
254      */
255     private static void add(List<DiscreteSamplerTestData> list,
256                             int[] points,
257                             final double[] probabilities,
258                             final DiscreteSampler sampler) {
259         list.add(new DiscreteSamplerTestData(sampler,
260                                              points,
261                                              probabilities));
262     }
263 
264     /**
265      * Subclasses that are "parametric" tests can forward the call to
266      * the "@Parameters"-annotated method to this method.
267      *
268      * @return the list of all generators.
269      */
270     public static Iterable<DiscreteSamplerTestData> list() {
271         return Collections.unmodifiableList(LIST);
272     }
273 
274     /**
275      * @param dist Distribution.
276      * @param points Points.
277      * @return the probabilities of the given points according to the distribution.
278      */
279     private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
280                                              int[] points) {
281         final int len = points.length;
282         final double[] prob = new double[len];
283         for (int i = 0; i < len; i++) {
284             prob[i] = dist instanceof org.apache.commons.math3.distribution.UniformIntegerDistribution ? // XXX Workaround.
285                 getProbability((org.apache.commons.math3.distribution.UniformIntegerDistribution) dist) :
286                 dist.probability(points[i]);
287 
288             if (prob[i] < 0) {
289                 throw new IllegalStateException(dist + ": p < 0 (at " + points[i] + ", p=" + prob[i]);
290             }
291         }
292         return prob;
293     }
294 
295     /**
296      * Workaround bugs in Commons Math's "UniformIntegerDistribution" (cf. MATH-1396).
297      */
298     private static double getProbability(org.apache.commons.math3.distribution.UniformIntegerDistribution dist) {
299         return 1 / ((double) dist.getSupportUpperBound() - (double) dist.getSupportLowerBound() + 1);
300     }
301 }