1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.shape;
18
19 import java.util.Arrays;
20 import java.util.function.DoubleUnaryOperator;
21 import org.apache.commons.math3.stat.inference.ChiSquareTest;
22 import org.apache.commons.rng.UniformRandomProvider;
23 import org.apache.commons.rng.core.source64.SplitMix64;
24 import org.apache.commons.rng.sampling.RandomAssert;
25 import org.junit.jupiter.api.Assertions;
26 import org.junit.jupiter.api.Test;
27
28
29
30
31 class UnitBallSamplerTest {
32
33
34
35 @Test
36 void testInvalidDimensionThrows() {
37 final UniformRandomProvider rng = RandomAssert.seededRNG();
38 Assertions.assertThrows(IllegalArgumentException.class,
39 () -> UnitBallSampler.of(rng, 0));
40 }
41
42
43
44
45 @Test
46 void testDistribution1D() {
47 testDistributionND(1);
48 }
49
50
51
52
53 @Test
54 void testDistribution2D() {
55 testDistributionND(2);
56 }
57
58
59
60
61 @Test
62 void testDistribution3D() {
63 testDistributionND(3);
64 }
65
66
67
68
69 @Test
70 void testDistribution4D() {
71 testDistributionND(4);
72 }
73
74
75
76
77 @Test
78 void testDistribution5D() {
79 testDistributionND(5);
80 }
81
82
83
84
85 @Test
86 void testDistribution6D() {
87 testDistributionND(6);
88 }
89
90
91
92
93
94
95
96
97
98
99
100
101 private static void testDistributionND(int dimension) {
102
103
104
105
106 final int layers = 10;
107 final int samplesPerBin = 20;
108 final int orthants = 1 << dimension;
109
110
111 final double volume = createVolumeFunction(dimension).applyAsDouble(1);
112 final DoubleUnaryOperator radius = createRadiusFunction(dimension);
113 final double[] r = new double[layers];
114 for (int i = 1; i < layers; i++) {
115 r[i - 1] = radius.applyAsDouble(volume * ((double) i / layers));
116 }
117
118
119 r[layers - 1] = 1.0;
120
121
122 final double[] expected = new double[layers * orthants];
123 final int samples = samplesPerBin * expected.length;
124 Arrays.fill(expected, (double) samples / layers);
125
126
127 final UniformRandomProvider rng = RandomAssert.createRNG();
128 final UnitBallSampler sampler = UnitBallSampler.of(rng, dimension);
129 for (int loop = 0; loop < 1; loop++) {
130
131 final long[] observed = new long[layers * orthants];
132 NEXT:
133 for (int i = 0; i < samples; i++) {
134 final double[] v = sampler.sample();
135 final double length = length(v);
136 for (int layer = 0; layer < layers; layer++) {
137 if (length <= r[layer]) {
138 final int orthant = orthant(v);
139 observed[layer * orthants + orthant]++;
140 continue NEXT;
141 }
142 }
143
144 Assertions.fail("Invalid sample length: " + length);
145 }
146 final double p = new ChiSquareTest().chiSquareTest(expected, observed);
147 Assertions.assertFalse(p < 0.001, () -> "p-value too small: " + p);
148 }
149 }
150
151
152
153
154 @Test
155 void testInvalidInverseNormalisation3D() {
156 testInvalidInverseNormalisationND(3);
157 }
158
159
160
161
162 @Test
163 void testInvalidInverseNormalisation4D() {
164 testInvalidInverseNormalisationND(4);
165 }
166
167
168
169
170
171 private static void testInvalidInverseNormalisationND(final int dimension) {
172
173
174 final UniformRandomProvider bad = new SplitMix64(0x1a2b3cL) {
175 private int count = -2 * dimension;
176
177 @Override
178 public long nextLong() {
179
180 return count++ < 0 ? 0 : super.nextLong();
181 }
182 };
183
184 final double[] vector = UnitBallSampler.of(bad, dimension).sample();
185 Assertions.assertEquals(dimension, vector.length);
186
187 Assertions.assertNotEquals(0.0, length(vector));
188 }
189
190
191
192
193 @Test
194 void testSharedStateSampler1D() {
195 testSharedStateSampler(1);
196 }
197
198
199
200
201 @Test
202 void testSharedStateSampler2D() {
203 testSharedStateSampler(2);
204 }
205
206
207
208
209 @Test
210 void testSharedStateSampler3D() {
211 testSharedStateSampler(3);
212 }
213
214
215
216
217 @Test
218 void testSharedStateSampler4D() {
219 testSharedStateSampler(4);
220 }
221
222
223
224
225 private static void testSharedStateSampler(int dimension) {
226 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
227 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
228 final UnitBallSampler sampler1 = UnitBallSampler.of(rng1, dimension);
229 final UnitBallSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
230 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
231 }
232
233
234
235
236 private static double length(double[] vector) {
237 double total = 0;
238 for (double d : vector) {
239 total += d * d;
240 }
241 return Math.sqrt(total);
242 }
243
244
245
246
247
248
249
250
251
252 private static int orthant(double[] vector) {
253 int orthant = 0;
254 for (int i = 0; i < vector.length; i++) {
255 if (vector[i] < 0) {
256 orthant |= 1 << i;
257 }
258 }
259 return orthant;
260 }
261
262
263
264
265
266
267 @Test
268 void checkVolumeFunctions() {
269 final double[] radii = {0, 0.1, 0.25, 0.5, 0.75, 1.0};
270 for (int n = 1; n <= 6; n++) {
271 final DoubleUnaryOperator volume = createVolumeFunction(n);
272 final DoubleUnaryOperator radius = createRadiusFunction(n);
273 for (final double r : radii) {
274 Assertions.assertEquals(r, radius.applyAsDouble(volume.applyAsDouble(r)), 1e-10);
275 }
276 }
277 }
278
279
280
281
282
283
284
285
286
287 private static DoubleUnaryOperator createVolumeFunction(final int dimension) {
288 if (dimension == 1) {
289 return r -> r * 2;
290 } else if (dimension == 2) {
291 return r -> Math.PI * r * r;
292 } else if (dimension == 3) {
293 final double factor = 4 * Math.PI / 3;
294 return r -> factor * Math.pow(r, 3);
295 } else if (dimension == 4) {
296 final double factor = Math.PI * Math.PI / 2;
297 return r -> factor * Math.pow(r, 4);
298 } else if (dimension == 5) {
299 final double factor = 8 * Math.PI * Math.PI / 15;
300 return r -> factor * Math.pow(r, 5);
301 } else if (dimension == 6) {
302 final double factor = Math.pow(Math.PI, 3) / 6;
303 return r -> factor * Math.pow(r, 6);
304 }
305 throw new IllegalStateException("Unsupported dimension: " + dimension);
306 }
307
308
309
310
311
312
313
314
315
316 private static DoubleUnaryOperator createRadiusFunction(final int dimension) {
317 if (dimension == 1) {
318 return v -> v * 0.5;
319 } else if (dimension == 2) {
320 return v -> Math.sqrt(v / Math.PI);
321 } else if (dimension == 3) {
322 final double factor = 3.0 / (4 * Math.PI);
323 return v -> Math.cbrt(v * factor);
324 } else if (dimension == 4) {
325 final double factor = 2.0 / (Math.PI * Math.PI);
326 return v -> Math.pow(v * factor, 0.25);
327 } else if (dimension == 5) {
328 final double factor = 15.0 / (8 * Math.PI * Math.PI);
329 return v -> Math.pow(v * factor, 0.2);
330 } else if (dimension == 6) {
331 final double factor = 6.0 / Math.pow(Math.PI, 3);
332 return v -> Math.pow(v * factor, 1.0 / 6);
333 }
334 throw new IllegalStateException("Unsupported dimension: " + dimension);
335 }
336 }