View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.shape;
18  
19  import java.util.Arrays;
20  import java.util.function.DoubleUnaryOperator;
21  import org.apache.commons.math3.stat.inference.ChiSquareTest;
22  import org.apache.commons.rng.UniformRandomProvider;
23  import org.apache.commons.rng.core.source64.SplitMix64;
24  import org.apache.commons.rng.sampling.RandomAssert;
25  import org.junit.jupiter.api.Assertions;
26  import org.junit.jupiter.api.Test;
27  
28  /**
29   * Test for {@link UnitBallSampler}.
30   */
31  class UnitBallSamplerTest {
32      /**
33       * Test a non-positive dimension.
34       */
35      @Test
36      void testInvalidDimensionThrows() {
37          final UniformRandomProvider rng = RandomAssert.seededRNG();
38          Assertions.assertThrows(IllegalArgumentException.class,
39              () -> UnitBallSampler.of(rng, 0));
40      }
41  
42      /**
43       * Test the distribution of points in one dimension.
44       */
45      @Test
46      void testDistribution1D() {
47          testDistributionND(1);
48      }
49  
50      /**
51       * Test the distribution of points in two dimensions.
52       */
53      @Test
54      void testDistribution2D() {
55          testDistributionND(2);
56      }
57  
58      /**
59       * Test the distribution of points in three dimensions.
60       */
61      @Test
62      void testDistribution3D() {
63          testDistributionND(3);
64      }
65  
66      /**
67       * Test the distribution of points in four dimensions.
68       */
69      @Test
70      void testDistribution4D() {
71          testDistributionND(4);
72      }
73  
74      /**
75       * Test the distribution of points in five dimensions.
76       */
77      @Test
78      void testDistribution5D() {
79          testDistributionND(5);
80      }
81  
82      /**
83       * Test the distribution of points in six dimensions.
84       */
85      @Test
86      void testDistribution6D() {
87          testDistributionND(6);
88      }
89  
90      /**
91       * Test the distribution of points in n dimensions. The output coordinates
92       * should be uniform in the unit n-ball. The unit n-ball is divided into inner
93       * n-balls. The radii of the internal n-balls are varied to ensure that successive layers
94       * have the same volume. This assigns each coordinate to an inner n-ball layer and an
95       * orthant using the sign bits of the coordinates. The number of samples in each bin
96       * should be the same.
97       *
98       * @see <a href="https://en.wikipedia.org/wiki/Volume_of_an_n-ball">Volume of an n-ball</a>
99       * @see <a href="https://en.wikipedia.org/wiki/Orthant">Orthant</a>
100      */
101     private static void testDistributionND(int dimension) {
102         // The number of inner layers and samples has been selected by trial and error using
103         // random seeds and multiple runs to ensure correctness of the test (i.e. it fails with
104         // approximately the fraction expected for the test p-value).
105         // A fixed seed is used to make the test suite robust.
106         final int layers = 10;
107         final int samplesPerBin = 20;
108         final int orthants = 1 << dimension;
109 
110         // Compute the radius for each layer to have the same volume.
111         final double volume = createVolumeFunction(dimension).applyAsDouble(1);
112         final DoubleUnaryOperator radius = createRadiusFunction(dimension);
113         final double[] r = new double[layers];
114         for (int i = 1; i < layers; i++) {
115             r[i - 1] = radius.applyAsDouble(volume * ((double) i / layers));
116         }
117         // The final radius should be 1.0. Any coordinates with a radius above 1
118         // should fail so explicitly set the value as 1.
119         r[layers - 1] = 1.0;
120 
121         // Expect a uniform distribution
122         final double[] expected = new double[layers * orthants];
123         final int samples = samplesPerBin * expected.length;
124         Arrays.fill(expected, (double) samples / layers);
125 
126         // Increase the loops to verify robustness
127         final UniformRandomProvider rng = RandomAssert.createRNG();
128         final UnitBallSampler sampler = UnitBallSampler.of(rng, dimension);
129         for (int loop = 0; loop < 1; loop++) {
130             // Assign each coordinate to a layer inside the ball and an orthant using the sign
131             final long[] observed = new long[layers * orthants];
132             NEXT:
133             for (int i = 0; i < samples; i++) {
134                 final double[] v = sampler.sample();
135                 final double length = length(v);
136                 for (int layer = 0; layer < layers; layer++) {
137                     if (length <= r[layer]) {
138                         final int orthant = orthant(v);
139                         observed[layer * orthants + orthant]++;
140                         continue NEXT;
141                     }
142                 }
143                 // Radius above 1
144                 Assertions.fail("Invalid sample length: " + length);
145             }
146             final double p = new ChiSquareTest().chiSquareTest(expected, observed);
147             Assertions.assertFalse(p < 0.001, () -> "p-value too small: " + p);
148         }
149     }
150 
151     /**
152      * Test the edge case where the normalisation sum to divide by is zero for 3D.
153      */
154     @Test
155     void testInvalidInverseNormalisation3D() {
156         testInvalidInverseNormalisationND(3);
157     }
158 
159     /**
160      * Test the edge case where the normalisation sum to divide by is zero for 4D.
161      */
162     @Test
163     void testInvalidInverseNormalisation4D() {
164         testInvalidInverseNormalisationND(4);
165     }
166 
167     /**
168      * Test the edge case where the normalisation sum to divide by is zero.
169      * This test requires generation of Gaussian samples with the value 0.
170      */
171     private static void testInvalidInverseNormalisationND(final int dimension) {
172         // Create a provider that will create a bad first sample but then recover.
173         // This checks recursion will return a good value.
174         final UniformRandomProvider bad = new SplitMix64(0x1a2b3cL) {
175             private int count = -2 * dimension;
176 
177             @Override
178             public long nextLong() {
179                 // Return enough zeros to create Gaussian samples of zero for all coordinates.
180                 return count++ < 0 ? 0 : super.nextLong();
181             }
182         };
183 
184         final double[] vector = UnitBallSampler.of(bad, dimension).sample();
185         Assertions.assertEquals(dimension, vector.length);
186         // A non-zero coordinate should occur with a SplitMix which returns 0 only once.
187         Assertions.assertNotEquals(0.0, length(vector));
188     }
189 
190     /**
191      * Test the SharedStateSampler implementation for 1D.
192      */
193     @Test
194     void testSharedStateSampler1D() {
195         testSharedStateSampler(1);
196     }
197 
198     /**
199      * Test the SharedStateSampler implementation for 2D.
200      */
201     @Test
202     void testSharedStateSampler2D() {
203         testSharedStateSampler(2);
204     }
205 
206     /**
207      * Test the SharedStateSampler implementation for 3D.
208      */
209     @Test
210     void testSharedStateSampler3D() {
211         testSharedStateSampler(3);
212     }
213 
214     /**
215      * Test the SharedStateSampler implementation for 4D.
216      */
217     @Test
218     void testSharedStateSampler4D() {
219         testSharedStateSampler(4);
220     }
221 
222     /**
223      * Test the SharedStateSampler implementation for the given dimension.
224      */
225     private static void testSharedStateSampler(int dimension) {
226         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
227         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
228         final UnitBallSampler sampler1 = UnitBallSampler.of(rng1, dimension);
229         final UnitBallSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
230         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
231     }
232 
233     /**
234      * @return the length (L2-norm) of given vector.
235      */
236     private static double length(double[] vector) {
237         double total = 0;
238         for (double d : vector) {
239             total += d * d;
240         }
241         return Math.sqrt(total);
242     }
243 
244     /**
245      * Assign an orthant to the vector using the sign of each component.
246      * The i<sup>th</sup> bit is set in the orthant for the i<sup>th</sup> component
247      * if the component is negative.
248      *
249      * @return the orthant in the range [0, vector.length)
250      * @see <a href="https://en.wikipedia.org/wiki/Orthant">Orthant</a>
251      */
252     private static int orthant(double[] vector) {
253         int orthant = 0;
254         for (int i = 0; i < vector.length; i++) {
255             if (vector[i] < 0) {
256                 orthant |= 1 << i;
257             }
258         }
259         return orthant;
260     }
261 
262     /**
263      * Check the n-ball volume functions can map the radius to the volume and back.
264      * These functions are used to divide the n-ball into uniform volume bins to test sampling
265      * within the n-ball.
266      */
267     @Test
268     void checkVolumeFunctions() {
269         final double[] radii = {0, 0.1, 0.25, 0.5, 0.75, 1.0};
270         for (int n = 1; n <= 6; n++) {
271             final DoubleUnaryOperator volume = createVolumeFunction(n);
272             final DoubleUnaryOperator radius = createRadiusFunction(n);
273             for (final double r : radii) {
274                 Assertions.assertEquals(r, radius.applyAsDouble(volume.applyAsDouble(r)), 1e-10);
275             }
276         }
277     }
278 
279     /**
280      * Creates a function to compute the volume of a ball of the given dimension
281      * from the radius.
282      *
283      * @param dimension the dimension
284      * @return the volume function
285      * @see <a href="https://en.wikipedia.org/wiki/Volume_of_an_n-ball">Volume of an n-ball</a>
286      */
287     private static DoubleUnaryOperator createVolumeFunction(final int dimension) {
288         if (dimension == 1) {
289             return r -> r * 2;
290         } else if (dimension == 2) {
291             return r -> Math.PI * r * r;
292         } else if (dimension == 3) {
293             final double factor = 4 * Math.PI / 3;
294             return r -> factor * Math.pow(r, 3);
295         } else if (dimension == 4) {
296             final double factor = Math.PI * Math.PI / 2;
297             return r -> factor * Math.pow(r, 4);
298         } else if (dimension == 5) {
299             final double factor = 8 * Math.PI * Math.PI / 15;
300             return r -> factor * Math.pow(r, 5);
301         } else if (dimension == 6) {
302             final double factor = Math.pow(Math.PI, 3) / 6;
303             return r -> factor * Math.pow(r, 6);
304         }
305         throw new IllegalStateException("Unsupported dimension: " + dimension);
306     }
307 
308     /**
309      * Creates a function to compute the radius of a ball of the given dimension
310      * from the volume.
311      *
312      * @param dimension the dimension
313      * @return the radius function
314      * @see <a href="https://en.wikipedia.org/wiki/Volume_of_an_n-ball">Volume of an n-ball</a>
315      */
316     private static DoubleUnaryOperator createRadiusFunction(final int dimension) {
317         if (dimension == 1) {
318             return v -> v * 0.5;
319         } else if (dimension == 2) {
320             return v -> Math.sqrt(v / Math.PI);
321         } else if (dimension == 3) {
322             final double factor = 3.0 / (4 * Math.PI);
323             return v -> Math.cbrt(v * factor);
324         } else if (dimension == 4) {
325             final double factor = 2.0 / (Math.PI * Math.PI);
326             return v -> Math.pow(v * factor, 0.25);
327         } else if (dimension == 5) {
328             final double factor = 15.0 / (8 * Math.PI * Math.PI);
329             return v -> Math.pow(v * factor, 0.2);
330         } else if (dimension == 6) {
331             final double factor = 6.0 / Math.pow(Math.PI, 3);
332             return v -> Math.pow(v * factor, 1.0 / 6);
333         }
334         throw new IllegalStateException("Unsupported dimension: " + dimension);
335     }
336 }