1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.util.ArrayList;
20 import java.util.Collections;
21 import java.util.List;
22 import org.apache.commons.math3.util.MathArrays;
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.sampling.RandomAssert;
25
26
27
28
29 public final class DiscreteSamplersList {
30
31 private static final List<DiscreteSamplerTestData> LIST = new ArrayList<>();
32
33 static {
34 try {
35
36
37
38
39
40 org.apache.commons.math3.random.RandomGenerator unusedRng = null;
41
42
43
44
45 final int trialsBinomial = 20;
46 final double probSuccessBinomial = 0.67;
47 add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
48 MathArrays.sequence(8, 9, 1),
49 RandomAssert.createRNG());
50 add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
51
52 MathArrays.sequence(8, 9, 1),
53 MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomAssert.createRNG(), trialsBinomial, probSuccessBinomial));
54
55 add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, 1 - probSuccessBinomial),
56
57 MathArrays.sequence(8, 4, 1),
58 MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomAssert.createRNG(), trialsBinomial, 1 - probSuccessBinomial));
59
60
61 final double probSuccessGeometric = 0.21;
62 add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
63 MathArrays.sequence(10, 0, 1),
64 RandomAssert.createRNG());
65
66 add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
67 MathArrays.sequence(10, 0, 1),
68 GeometricSampler.of(RandomAssert.createRNG(), probSuccessGeometric));
69
70
71 final int popSizeHyper = 34;
72 final int numSuccessesHyper = 11;
73 final int sampleSizeHyper = 12;
74 add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(unusedRng, popSizeHyper, numSuccessesHyper, sampleSizeHyper),
75 MathArrays.sequence(10, 0, 1),
76 RandomAssert.createRNG());
77
78
79 final int numSuccessesPascal = 6;
80 final double probSuccessPascal = 0.2;
81 add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(unusedRng, numSuccessesPascal, probSuccessPascal),
82 MathArrays.sequence(18, 1, 1),
83 RandomAssert.createRNG());
84
85
86 final int loUniform = -3;
87 final int hiUniform = 4;
88 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
89 MathArrays.sequence(8, -3, 1),
90 RandomAssert.createRNG());
91
92 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
93 MathArrays.sequence(8, -3, 1),
94 DiscreteUniformSampler.of(RandomAssert.createRNG(), loUniform, hiUniform));
95
96 final int halfMax = Integer.MAX_VALUE / 2;
97 final int hiLargeUniform = halfMax + 10;
98 final int loLargeUniform = -hiLargeUniform;
99 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loLargeUniform, hiLargeUniform),
100 MathArrays.sequence(20, -halfMax, halfMax / 10),
101 DiscreteUniformSampler.of(RandomAssert.createRNG(), loLargeUniform, hiLargeUniform));
102
103 final int rangeNonPowerOf2Uniform = 11;
104 final int hiNonPowerOf2Uniform = loUniform + rangeNonPowerOf2Uniform;
105 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiNonPowerOf2Uniform),
106 MathArrays.sequence(rangeNonPowerOf2Uniform, -3, 1),
107 DiscreteUniformSampler.of(RandomAssert.createRNG(), loUniform, hiNonPowerOf2Uniform));
108
109
110 final int numElementsZipf = 5;
111 final double exponentZipf = 2.345;
112 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
113 MathArrays.sequence(5, 1, 1),
114 RandomAssert.createRNG());
115
116 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
117 MathArrays.sequence(5, 1, 1),
118 RejectionInversionZipfSampler.of(RandomAssert.createRNG(), numElementsZipf, exponentZipf));
119
120 final double exponentCloseToOneZipf = 1 - 1e-10;
121 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentCloseToOneZipf),
122 MathArrays.sequence(5, 1, 1),
123 RejectionInversionZipfSampler.of(RandomAssert.createRNG(), numElementsZipf, exponentCloseToOneZipf));
124
125 add(LIST, MathArrays.sequence(5, 1, 1), new double[] {0.2, 0.2, 0.2, 0.2, 0.2},
126 RejectionInversionZipfSampler.of(RandomAssert.createRNG(), numElementsZipf, 0.0));
127
128
129 final double epsilonPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_EPSILON;
130 final int maxIterationsPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_MAX_ITERATIONS;
131 final double meanPoisson = 3.21;
132 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
133 MathArrays.sequence(10, 0, 1),
134 RandomAssert.createRNG());
135
136 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
137 MathArrays.sequence(10, 0, 1),
138 PoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
139
140 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
141 MathArrays.sequence(10, 0, 1),
142 SmallMeanPoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
143 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
144 MathArrays.sequence(10, 0, 1),
145 KempSmallMeanPoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
146 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
147 MathArrays.sequence(10, 0, 1),
148 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomAssert.createRNG(), meanPoisson));
149
150
151 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
152 MathArrays.sequence(10, 0, 1),
153 LargeMeanPoissonSampler.of(RandomAssert.createRNG(), meanPoisson));
154
155 final double largeMeanPoisson = 67.89;
156 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
157 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
158 PoissonSampler.of(RandomAssert.createRNG(), largeMeanPoisson));
159
160 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
161 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
162 LargeMeanPoissonSampler.of(RandomAssert.createRNG(), largeMeanPoisson));
163 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
164 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
165 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomAssert.createRNG(), largeMeanPoisson));
166
167 final double veryLargeMeanPoisson = 543.21;
168 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
169 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
170 PoissonSampler.of(RandomAssert.createRNG(), veryLargeMeanPoisson));
171
172 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
173 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
174 LargeMeanPoissonSampler.of(RandomAssert.createRNG(), veryLargeMeanPoisson));
175 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
176 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
177 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomAssert.createRNG(), veryLargeMeanPoisson));
178
179
180 final int[] discretePoints = {0, 1, 2, 3, 4};
181 final double[] discreteProbabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
182 final long[] discreteFrequencies = {1, 2, 3, 4, 5};
183 add(LIST, discretePoints, discreteProbabilities,
184 MarsagliaTsangWangDiscreteSampler.Enumerated.of(RandomAssert.createRNG(), discreteProbabilities));
185 add(LIST, discretePoints, discreteProbabilities,
186 GuideTableDiscreteSampler.of(RandomAssert.createRNG(), discreteProbabilities));
187 add(LIST, discretePoints, discreteProbabilities,
188 AliasMethodDiscreteSampler.of(RandomAssert.createRNG(), discreteProbabilities));
189 add(LIST, discretePoints, discreteProbabilities,
190 FastLoadedDiceRollerDiscreteSampler.of(RandomAssert.createRNG(), discreteFrequencies));
191 add(LIST, discretePoints, discreteProbabilities,
192 FastLoadedDiceRollerDiscreteSampler.of(RandomAssert.createRNG(), discreteProbabilities));
193 } catch (Exception e) {
194
195 System.err.println("Unexpected exception while creating the list of samplers: " + e);
196 e.printStackTrace(System.err);
197
198 throw new RuntimeException(e);
199 }
200 }
201
202
203
204
205 private DiscreteSamplersList() {}
206
207
208
209
210
211
212
213 private static void add(List<DiscreteSamplerTestData> list,
214 final org.apache.commons.math3.distribution.IntegerDistribution dist,
215 int[] points,
216 UniformRandomProvider rng) {
217 final DiscreteSampler inverseMethodSampler =
218 InverseTransformDiscreteSampler.of(rng,
219 new DiscreteInverseCumulativeProbabilityFunction() {
220 @Override
221 public int inverseCumulativeProbability(double p) {
222 return dist.inverseCumulativeProbability(p);
223 }
224 @Override
225 public String toString() {
226 return dist.toString();
227 }
228 });
229 list.add(new DiscreteSamplerTestData(inverseMethodSampler,
230 points,
231 getProbabilities(dist, points)));
232 }
233
234
235
236
237
238
239
240 private static void add(List<DiscreteSamplerTestData> list,
241 final org.apache.commons.math3.distribution.IntegerDistribution dist,
242 int[] points,
243 final DiscreteSampler sampler) {
244 list.add(new DiscreteSamplerTestData(sampler,
245 points,
246 getProbabilities(dist, points)));
247 }
248
249
250
251
252
253
254
255 private static void add(List<DiscreteSamplerTestData> list,
256 int[] points,
257 final double[] probabilities,
258 final DiscreteSampler sampler) {
259 list.add(new DiscreteSamplerTestData(sampler,
260 points,
261 probabilities));
262 }
263
264
265
266
267
268
269
270 public static Iterable<DiscreteSamplerTestData> list() {
271 return Collections.unmodifiableList(LIST);
272 }
273
274
275
276
277
278
279 private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
280 int[] points) {
281 final int len = points.length;
282 final double[] prob = new double[len];
283 for (int i = 0; i < len; i++) {
284 prob[i] = dist instanceof org.apache.commons.math3.distribution.UniformIntegerDistribution ?
285 getProbability((org.apache.commons.math3.distribution.UniformIntegerDistribution) dist) :
286 dist.probability(points[i]);
287
288 if (prob[i] < 0) {
289 throw new IllegalStateException(dist + ": p < 0 (at " + points[i] + ", p=" + prob[i]);
290 }
291 }
292 return prob;
293 }
294
295
296
297
298 private static double getProbability(org.apache.commons.math3.distribution.UniformIntegerDistribution dist) {
299 return 1 / ((double) dist.getSupportUpperBound() - (double) dist.getSupportLowerBound() + 1);
300 }
301 }