1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.examples.jmh.sampling.distribution;
18
19 import org.junit.jupiter.api.Assertions;
20 import org.junit.jupiter.params.ParameterizedTest;
21 import org.junit.jupiter.params.provider.Arguments;
22 import org.junit.jupiter.params.provider.MethodSource;
23 import java.util.Arrays;
24 import java.util.function.Function;
25 import java.util.function.Supplier;
26 import java.util.stream.Stream;
27 import org.apache.commons.math3.distribution.AbstractRealDistribution;
28 import org.apache.commons.math3.distribution.ExponentialDistribution;
29 import org.apache.commons.math3.distribution.NormalDistribution;
30 import org.apache.commons.math3.stat.inference.ChiSquareTest;
31 import org.apache.commons.rng.RestorableUniformRandomProvider;
32 import org.apache.commons.rng.UniformRandomProvider;
33 import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
34 import org.apache.commons.rng.simple.RandomSource;
35
36
37
38
39
40
41
42 class ZigguratSamplerTest {
43
44
45
46
47
48
49
50
51
52
53
54
55
56 private static final Long SEED = 0xd1342543de82ef95L;
57
58
59
60
61
62
63
64
65 private static Arguments args(String name) {
66
67
68 final Function<UniformRandomProvider, ContinuousSampler> factory =
69 rng -> ZigguratSamplerPerformance.Sources.createSampler(name, rng);
70 return Arguments.of(name, factory);
71 }
72
73
74
75
76
77
78
79
80
81 private static Stream<Arguments> gaussianSamplers() {
82
83 return Stream.of(
84 args(ZigguratSamplerPerformance.GAUSSIAN_128),
85 args(ZigguratSamplerPerformance.GAUSSIAN_256),
86 args(ZigguratSamplerPerformance.MOD_GAUSSIAN2),
87 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_SIMPLE_OVERHANGS),
88 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INLINING),
89 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INLINING_SHIFT),
90 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INLINING_SIMPLE_OVERHANGS),
91 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INT_MAP),
92 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_TABLE),
93 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_2),
94 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_TERNARY),
95 args(ZigguratSamplerPerformance.MOD_GAUSSIAN_512));
96 }
97
98
99
100
101
102
103
104
105
106 private static Stream<Arguments> exponentialSamplers() {
107
108 return Stream.of(
109 args(ZigguratSamplerPerformance.EXPONENTIAL),
110 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL2),
111 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_SIMPLE_OVERHANGS),
112 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_INLINING),
113 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_LOOP),
114 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_LOOP2),
115 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_RECURSION),
116 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_INT_MAP),
117 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_TABLE),
118 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_2),
119 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY),
120 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY_SUBTRACT),
121 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_512));
122 }
123
124
125
126
127
128
129
130
131
132
133
134 private static AbstractRealDistribution createGaussianDistribution() {
135 return new NormalDistribution(null, 0.0, 1.0);
136 }
137
138
139
140
141
142
143 private static AbstractRealDistribution createExponentialDistribution() {
144 return new ExponentialDistribution(null, 1.0);
145 }
146
147
148
149
150
151
152
153
154 @ParameterizedTest(name = "{index} => {0}")
155 @MethodSource("gaussianSamplers")
156 void testGaussianSamplesWithQuantiles(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
157 final int bins = 2000;
158 final AbstractRealDistribution dist = createGaussianDistribution();
159 final double[] quantiles = new double[bins];
160 for (int i = 0; i < bins; i++) {
161 quantiles[i] = dist.inverseCumulativeProbability((i + 1.0) / bins);
162 }
163 testSamples(quantiles, factory, ZigguratSamplerTest::createGaussianDistribution,
164
165
166 new double[] {0, 0.2},
167 new double[] {-0.35, -0.1},
168
169 new double[] {-0.1, 0.1},
170 new double[] {-0.4, 0.6},
171
172 new double[] {-1.1, -0.9},
173
174 new double[] {2.1, 2.5},
175
176 new double[] {2.5, 8});
177 }
178
179
180
181
182
183
184
185
186 @ParameterizedTest(name = "{index} => {0}")
187 @MethodSource("gaussianSamplers")
188 void testGaussianSamplesWithUniformValues(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
189 final int bins = 2000;
190 final double[] values = new double[bins];
191 final double minx = -8;
192 final double maxx = 8;
193
194 for (int i = 0; i < bins; i++) {
195 values[i] = minx + (maxx - minx) * (i + 1.0) / bins;
196 }
197
198 values[bins - 1] = Double.POSITIVE_INFINITY;
199 testSamples(values, factory, ZigguratSamplerTest::createGaussianDistribution,
200
201
202 new double[] {0, 0.2},
203 new double[] {-0.35, -0.1},
204
205 new double[] {-0.1, 0.1},
206 new double[] {-0.4, 0.6},
207
208 new double[] {-1.01, -0.99},
209 new double[] {0.98, 1.03},
210
211 new double[] {1.03, 1.05},
212
213 new double[] {3.6, 3.8},
214 new double[] {3.7, 8});
215 }
216
217
218
219
220
221
222
223 @ParameterizedTest(name = "{index} => {0}")
224 @MethodSource("exponentialSamplers")
225 void testExponentialSamplesWithQuantiles(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
226 final int bins = 2000;
227 final AbstractRealDistribution dist = createExponentialDistribution();
228 final double[] quantiles = new double[bins];
229 for (int i = 0; i < bins; i++) {
230 quantiles[i] = dist.inverseCumulativeProbability((i + 1.0) / bins);
231 }
232 testSamples(quantiles, factory, ZigguratSamplerTest::createExponentialDistribution,
233
234 new double[] {0, 0.1},
235 new double[] {0.05, 0.15},
236
237 new double[] {0.9, 1.1},
238
239 new double[] {1.5, 12});
240 }
241
242
243
244
245
246
247
248 @ParameterizedTest(name = "{index} => {0}")
249 @MethodSource("exponentialSamplers")
250 void testExponentialSamplesWithUniformValues(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
251 final int bins = 2000;
252 final double[] values = new double[bins];
253 final double minx = 0;
254
255 final double maxx = 12;
256
257 for (int i = 0; i < bins; i++) {
258 values[i] = minx + (maxx - minx) * (i + 1.0) / bins;
259 }
260
261 values[bins - 1] = Double.POSITIVE_INFINITY;
262
263 testSamples(values, factory, ZigguratSamplerTest::createExponentialDistribution,
264
265 new double[] {0, 0.1},
266 new double[] {0.05, 0.15},
267
268 new double[] {0.9, 1.1},
269
270 new double[] {7.5, 7.7},
271 new double[] {7.7, 12});
272 }
273
274
275
276
277
278
279
280
281
282
283
284 private static void testSamples(double[] values,
285 Function<UniformRandomProvider, ContinuousSampler> factory,
286 Supplier<AbstractRealDistribution> distribution,
287 double[]... ranges) {
288 final int bins = values.length;
289
290 final int samples = 10000000;
291 final long[] observed = new long[bins];
292 final RestorableUniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(SEED);
293 final ContinuousSampler sampler = factory.apply(rng);
294 for (int i = 0; i < samples; i++) {
295 final double x = sampler.sample();
296 final int index = findIndex(values, x);
297 observed[index]++;
298 }
299
300
301 final AbstractRealDistribution dist = distribution.get();
302 final double[] expected = new double[bins];
303 double x0 = Double.NEGATIVE_INFINITY;
304 for (int i = 0; i < bins; i++) {
305 final double x1 = values[i];
306 expected[i] = dist.probability(x0, x1);
307 x0 = x1;
308 }
309
310 final double significanceLevel = 0.001;
311
312 final double lowerBound = dist.getSupportLowerBound();
313
314 final ChiSquareTest chiSquareTest = new ChiSquareTest();
315
316 final double pValue = chiSquareTest.chiSquareTest(expected, observed);
317 Assertions.assertFalse(pValue < 0.001,
318 () -> String.format("(%s <= x < %s) Chi-square p-value = %s",
319 lowerBound, values[bins - 1], pValue));
320
321
322 for (final double[] range : ranges) {
323 final int min = findIndex(values, range[0]);
324 final int max = findIndex(values, range[1]);
325
326 if (max - min + 1 < 2) {
327
328
329
330 Assertions.fail("Invalid range: " + Arrays.toString(range));
331 }
332 final long[] observed2 = Arrays.copyOfRange(observed, min, max + 1);
333 final double[] expected2 = Arrays.copyOfRange(expected, min, max + 1);
334 final double pValueB = chiSquareTest.chiSquareTest(expected2, observed2);
335 Assertions.assertFalse(pValueB < significanceLevel,
336 () -> String.format("(%s <= x < %s) Chi-square p-value = %s",
337 min == 0 ? lowerBound : values[min - 1], values[max], pValueB));
338 }
339 }
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355 private static int findIndex(double[] data, double x) {
356 int low = 0;
357 int high = data.length - 1;
358
359
360 while (low <= high) {
361 final int mid = (low + high) >>> 1;
362 final double midVal = data[mid];
363
364 if (x < midVal) {
365
366 high = mid - 1;
367 } else {
368
369 low = mid + 1;
370 }
371 }
372
373 Assertions.assertTrue(x < data[low]);
374 if (low != 0) {
375 Assertions.assertTrue(x >= data[low - 1]);
376 }
377 return low;
378 }
379 }