View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.examples.jmh.sampling.distribution;
18  
19  import org.junit.jupiter.api.Assertions;
20  import org.junit.jupiter.params.ParameterizedTest;
21  import org.junit.jupiter.params.provider.Arguments;
22  import org.junit.jupiter.params.provider.MethodSource;
23  import java.util.Arrays;
24  import java.util.function.Function;
25  import java.util.function.Supplier;
26  import java.util.stream.Stream;
27  import org.apache.commons.math3.distribution.AbstractRealDistribution;
28  import org.apache.commons.math3.distribution.ExponentialDistribution;
29  import org.apache.commons.math3.distribution.NormalDistribution;
30  import org.apache.commons.math3.stat.inference.ChiSquareTest;
31  import org.apache.commons.rng.RestorableUniformRandomProvider;
32  import org.apache.commons.rng.UniformRandomProvider;
33  import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
34  import org.apache.commons.rng.simple.RandomSource;
35  
36  /**
37   * Test for ziggurat samplers in the {@link ZigguratSamplerPerformance} class.
38   *
39   * <p>This test is copied from the {@code commons-rng-sampling} module to ensure all implementations
40   * correctly sample from the distribution.
41   */
42  class ZigguratSamplerTest {
43  
44      /**
45       * The seed for the RNG used in the distribution sampling tests.
46       *
47       * <p>This has been chosen to allow the test to pass with all generators.
48       * Set to null test with a random seed.
49       *
50       * <p>Note that the p-value of the chi-square test is 0.001. There are multiple assertions
51       * per test and multiple samplers. The total number of chi-square tests is above 100
52       * and failure of a chosen random seed on a few tests is common. When using a random
53       * seed re-run the test multiple times. Systematic failure of the same sampler
54       * should be investigated further.
55       */
56      private static final Long SEED = 0xd1342543de82ef95L;
57  
58      /**
59       * Create arguments with the name of the factory.
60       *
61       * @param name Name of the factory
62       * @param factory Factory to create the sampler
63       * @return the arguments
64       */
65      private static Arguments args(String name) {
66          // Create the factory.
67          // Here we delegate to the static method used to create all the samplers for testing.
68          final Function<UniformRandomProvider, ContinuousSampler> factory =
69              rng -> ZigguratSamplerPerformance.Sources.createSampler(name, rng);
70          return Arguments.of(name, factory);
71      }
72  
73      /**
74       * Create a stream of constructors of a Gaussian sampler.
75       *
76       * <p>Note: This method exists to allow this test to be duplicated in the examples JMH
77       * module where many implementations are tested.
78       *
79       * @return the stream of constructors
80       */
81      private static Stream<Arguments> gaussianSamplers() {
82          // Test all but MOD_GAUSSIAN (tested in the common-rng-sampling module)
83          return Stream.of(
84              args(ZigguratSamplerPerformance.GAUSSIAN_128),
85              args(ZigguratSamplerPerformance.GAUSSIAN_256),
86              args(ZigguratSamplerPerformance.MOD_GAUSSIAN2),
87              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_SIMPLE_OVERHANGS),
88              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INLINING),
89              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INLINING_SHIFT),
90              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INLINING_SIMPLE_OVERHANGS),
91              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INT_MAP),
92              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_TABLE),
93              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_2),
94              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_TERNARY),
95              args(ZigguratSamplerPerformance.MOD_GAUSSIAN_512));
96      }
97  
98      /**
99       * Create a stream of constructors of an exponential sampler.
100      *
101      * <p>Note: This method exists to allow this test to be duplicated in the examples JMH
102      * module where many implementations are tested.
103      *
104      * @return the stream of constructors
105      */
106     private static Stream<Arguments> exponentialSamplers() {
107         // Test all but MOD_EXPONENTIAL (tested in the common-rng-sampling module)
108         return Stream.of(
109                 args(ZigguratSamplerPerformance.EXPONENTIAL),
110                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL2),
111                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_SIMPLE_OVERHANGS),
112                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_INLINING),
113                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_LOOP),
114                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_LOOP2),
115                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_RECURSION),
116                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_INT_MAP),
117                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_TABLE),
118                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_2),
119                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY),
120                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY_SUBTRACT),
121                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_512));
122     }
123 
124     // -------------------------------------------------------------------------
125     // All code below here is copied from:
126     // commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratSamplerTest.java
127     // -------------------------------------------------------------------------
128 
129     /**
130      * Creates the gaussian distribution.
131      *
132      * @return the distribution
133      */
134     private static AbstractRealDistribution createGaussianDistribution() {
135         return new NormalDistribution(null, 0.0, 1.0);
136     }
137 
138     /**
139      * Creates the exponential distribution.
140      *
141      * @return the distribution
142      */
143     private static AbstractRealDistribution createExponentialDistribution() {
144         return new ExponentialDistribution(null, 1.0);
145     }
146 
147     /**
148      * Test Gaussian samples using a large number of bins based on uniformly spaced
149      * quantiles. Added for RNG-159.
150      *
151      * @param name Name of the sampler
152      * @param factory Factory to create the sampler
153      */
154     @ParameterizedTest(name = "{index} => {0}")
155     @MethodSource("gaussianSamplers")
156     void testGaussianSamplesWithQuantiles(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
157         final int bins = 2000;
158         final AbstractRealDistribution dist = createGaussianDistribution();
159         final double[] quantiles = new double[bins];
160         for (int i = 0; i < bins; i++) {
161             quantiles[i] = dist.inverseCumulativeProbability((i + 1.0) / bins);
162         }
163         testSamples(quantiles, factory, ZigguratSamplerTest::createGaussianDistribution,
164             // Test positive and negative ranges to check symmetry.
165             // Smallest layer is 0.292 (convex region)
166             new double[] {0, 0.2},
167             new double[] {-0.35, -0.1},
168             // Around the mean
169             new double[] {-0.1, 0.1},
170             new double[] {-0.4, 0.6},
171             // Inflection point at x=1
172             new double[] {-1.1, -0.9},
173             // A concave region
174             new double[] {2.1, 2.5},
175             // Tail = 3.64
176             new double[] {2.5, 8});
177     }
178 
179     /**
180      * Test Gaussian samples using a large number of bins uniformly spaced in a range.
181      * Added for RNG-159.
182      *
183      * @param name Name of the sampler
184      * @param factory Factory to create the sampler
185      */
186     @ParameterizedTest(name = "{index} => {0}")
187     @MethodSource("gaussianSamplers")
188     void testGaussianSamplesWithUniformValues(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
189         final int bins = 2000;
190         final double[] values = new double[bins];
191         final double minx = -8;
192         final double maxx = 8;
193         // Bin width = 16 / 2000 = 0.008
194         for (int i = 0; i < bins; i++) {
195             values[i] = minx + (maxx - minx) * (i + 1.0) / bins;
196         }
197         // Ensure upper bound is the support limit
198         values[bins - 1] = Double.POSITIVE_INFINITY;
199         testSamples(values, factory, ZigguratSamplerTest::createGaussianDistribution,
200             // Test positive and negative ranges to check symmetry.
201             // Smallest layer is 0.292 (convex region)
202             new double[] {0, 0.2},
203             new double[] {-0.35, -0.1},
204             // Around the mean
205             new double[] {-0.1, 0.1},
206             new double[] {-0.4, 0.6},
207             // Inflection point at x=1
208             new double[] {-1.01, -0.99},
209             new double[] {0.98, 1.03},
210             // A concave region
211             new double[] {1.03, 1.05},
212             // Tail = 3.64
213             new double[] {3.6, 3.8},
214             new double[] {3.7, 8});
215     }
216 
217     /**
218      * Test exponential samples using a large number of bins based on uniformly spaced quantiles.
219      *
220      * @param name Name of the sampler
221      * @param factory Factory to create the sampler
222      */
223     @ParameterizedTest(name = "{index} => {0}")
224     @MethodSource("exponentialSamplers")
225     void testExponentialSamplesWithQuantiles(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
226         final int bins = 2000;
227         final AbstractRealDistribution dist = createExponentialDistribution();
228         final double[] quantiles = new double[bins];
229         for (int i = 0; i < bins; i++) {
230             quantiles[i] = dist.inverseCumulativeProbability((i + 1.0) / bins);
231         }
232         testSamples(quantiles, factory, ZigguratSamplerTest::createExponentialDistribution,
233             // Smallest layer is 0.122
234             new double[] {0, 0.1},
235             new double[] {0.05, 0.15},
236             // Around the mean
237             new double[] {0.9, 1.1},
238             // Tail = 7.57
239             new double[] {1.5, 12});
240     }
241 
242     /**
243      * Test exponential samples using a large number of bins uniformly spaced in a range.
244      *
245      * @param name Name of the sampler
246      * @param factory Factory to create the sampler
247      */
248     @ParameterizedTest(name = "{index} => {0}")
249     @MethodSource("exponentialSamplers")
250     void testExponentialSamplesWithUniformValues(String name, Function<UniformRandomProvider, ContinuousSampler> factory) {
251         final int bins = 2000;
252         final double[] values = new double[bins];
253         final double minx = 0;
254         // Enter the tail of the distribution
255         final double maxx = 12;
256         // Bin width = 12 / 2000 = 0.006
257         for (int i = 0; i < bins; i++) {
258             values[i] = minx + (maxx - minx) * (i + 1.0) / bins;
259         }
260         // Ensure upper bound is the support limit
261         values[bins - 1] = Double.POSITIVE_INFINITY;
262 
263         testSamples(values, factory,  ZigguratSamplerTest::createExponentialDistribution,
264             // Smallest layer is 0.122
265             new double[] {0, 0.1},
266             new double[] {0.05, 0.15},
267             // Around the mean
268             new double[] {0.9, 1.1},
269             // Tail = 7.57
270             new double[] {7.5, 7.7},
271             new double[] {7.7, 12});
272     }
273 
274     /**
275      * Test samples using the provided bins. Values correspond to the bin upper
276      * limit. It is assumed the values span most of the distribution. Additional
277      * tests are performed using a region of the distribution sampled.
278      *
279      * @param values Bin upper limits
280      * @param factory Factory to create the sampler
281      * @param distribution The distribution under test
282      * @param ranges Ranges of the distribution to test
283      */
284     private static void testSamples(double[] values,
285                                     Function<UniformRandomProvider, ContinuousSampler> factory,
286                                     Supplier<AbstractRealDistribution> distribution,
287                                     double[]... ranges) {
288         final int bins = values.length;
289 
290         final int samples = 10000000;
291         final long[] observed = new long[bins];
292         final RestorableUniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(SEED);
293         final ContinuousSampler sampler = factory.apply(rng);
294         for (int i = 0; i < samples; i++) {
295             final double x = sampler.sample();
296             final int index = findIndex(values, x);
297             observed[index]++;
298         }
299 
300         // Compute expected
301         final AbstractRealDistribution dist = distribution.get();
302         final double[] expected = new double[bins];
303         double x0 = Double.NEGATIVE_INFINITY;
304         for (int i = 0; i < bins; i++) {
305             final double x1 = values[i];
306             expected[i] = dist.probability(x0, x1);
307             x0 = x1;
308         }
309 
310         final double significanceLevel = 0.001;
311 
312         final double lowerBound = dist.getSupportLowerBound();
313 
314         final ChiSquareTest chiSquareTest = new ChiSquareTest();
315         // Pass if we cannot reject null hypothesis that the distributions are the same.
316         final double pValue = chiSquareTest.chiSquareTest(expected, observed);
317         Assertions.assertFalse(pValue < 0.001,
318             () -> String.format("(%s <= x < %s) Chi-square p-value = %s",
319                                 lowerBound, values[bins - 1], pValue));
320 
321         // Test regions of the ziggurat.
322         for (final double[] range : ranges) {
323             final int min = findIndex(values, range[0]);
324             final int max = findIndex(values, range[1]);
325             // Must have a range of 2
326             if (max - min + 1 < 2) {
327                 // This will probably occur if the quantiles test uses too small a range
328                 // for the tail. The tail is so far into the CDF that a single bin is
329                 // often used to represent it.
330                 Assertions.fail("Invalid range: " + Arrays.toString(range));
331             }
332             final long[] observed2 = Arrays.copyOfRange(observed, min, max + 1);
333             final double[] expected2 = Arrays.copyOfRange(expected, min, max + 1);
334             final double pValueB = chiSquareTest.chiSquareTest(expected2, observed2);
335             Assertions.assertFalse(pValueB < significanceLevel,
336                 () -> String.format("(%s <= x < %s) Chi-square p-value = %s",
337                                     min == 0 ? lowerBound : values[min - 1], values[max], pValueB));
338         }
339     }
340 
341     /**
342      * Find the index of the value in the data such that:
343      * <pre>
344      * data[index - 1] <= x < data[index]
345      * </pre>
346      *
347      * <p>This is a specialised binary search that assumes the bounds of the data are the
348      * extremes of the support, and the upper support is infinite. Thus an index cannot
349      * be returned as equal to the data length.
350      *
351      * @param data the data
352      * @param x the value
353      * @return the index
354      */
355     private static int findIndex(double[] data, double x) {
356         int low = 0;
357         int high = data.length - 1;
358 
359         // Bracket so that low is just above the value x
360         while (low <= high) {
361             final int mid = (low + high) >>> 1;
362             final double midVal = data[mid];
363 
364             if (x < midVal) {
365                 // Reduce search range
366                 high = mid - 1;
367             } else {
368                 // Set data[low] above the value
369                 low = mid + 1;
370             }
371         }
372         // Verify the index is correct
373         Assertions.assertTrue(x < data[low]);
374         if (low != 0) {
375             Assertions.assertTrue(x >= data[low - 1]);
376         }
377         return low;
378     }
379 }