1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.core.util;
18
19 import java.util.Arrays;
20 import java.util.Spliterator;
21 import java.util.SplittableRandom;
22 import java.util.concurrent.ExecutionException;
23 import java.util.concurrent.ForkJoinPool;
24 import java.util.concurrent.atomic.AtomicLong;
25 import java.util.function.Consumer;
26 import java.util.function.IntConsumer;
27 import java.util.function.LongConsumer;
28 import java.util.function.Supplier;
29 import java.util.stream.LongStream;
30 import org.apache.commons.math3.stat.inference.ChiSquareTest;
31 import org.apache.commons.rng.SplittableUniformRandomProvider;
32 import org.apache.commons.rng.UniformRandomProvider;
33 import org.apache.commons.rng.core.util.RandomStreams.SeededObjectFactory;
34 import org.junit.jupiter.api.Assertions;
35 import org.junit.jupiter.api.Test;
36 import org.junit.jupiter.params.ParameterizedTest;
37 import org.junit.jupiter.params.provider.CsvSource;
38 import org.junit.jupiter.params.provider.ValueSource;
39
40
41
42
43 class RandomStreamsTest {
44
45 private static final int CHAR_BITS = 4;
46
47
48
49
50
51 private static class SequenceGenerator implements SplittableUniformRandomProvider {
52
53 private final AtomicLong value;
54
55
56
57
58 SequenceGenerator(long seed) {
59 value = new AtomicLong(seed);
60 }
61
62
63
64
65 SequenceGenerator(AtomicLong value) {
66 this.value = value;
67 }
68
69 @Override
70 public long nextLong() {
71 return value.getAndIncrement();
72 }
73
74 @Override
75 public SplittableUniformRandomProvider split(UniformRandomProvider source) {
76
77 return new SequenceGenerator(value);
78 }
79 }
80
81
82
83
84
85
86
87
88 private static class SeedDecoder implements Consumer<Long>, LongConsumer {
89
90 private final long initial;
91
92 private long seed;
93
94 private long position = -1;
95
96
97
98
99 SeedDecoder(long initial) {
100 this.initial = initial;
101 }
102
103 @Override
104 public void accept(long value) {
105 if (position < 0) {
106
107 seed = initial;
108 long mask = -1;
109 while (seed != 0 && (value & mask) != seed) {
110 seed <<= CHAR_BITS;
111 mask <<= CHAR_BITS;
112 }
113 if (seed == 0) {
114 Assertions.fail(() -> String.format("Failed to decode position from %s using seed %s",
115 Long.toBinaryString(value), Long.toBinaryString(initial)));
116 }
117
118 position = value & ~seed;
119 } else {
120
121 final long expected = position + 1;
122
123 while (seed != 0 && Long.compareUnsigned(Long.lowestOneBit(seed), expected) <= 0) {
124 seed <<= CHAR_BITS;
125 }
126 Assertions.assertEquals(expected | seed, value);
127 position = expected;
128 }
129 }
130
131 @Override
132 public void accept(Long t) {
133 accept(t.longValue());
134 }
135
136
137
138
139 void reset() {
140 position = -1;
141 }
142 }
143
144
145
146
147
148
149
150
151
152
153
154
155
156 @ParameterizedTest
157 @ValueSource(longs = {1628346812812L})
158 void testCreateSeed(long seed) {
159 final UniformRandomProvider rng = new SplittableRandom(seed)::nextLong;
160
161
162 final int m = (1 << CHAR_BITS) - 1;
163
164 final int n = (int) Math.ceil((Long.SIZE - CHAR_BITS) / CHAR_BITS);
165 final int[][] h = new int[m + 1][m + 1];
166 final int samples = 1 << 16;
167 for (int i = 0; i < samples; i++) {
168 long s = RandomStreams.createSeed(rng);
169 final int unique = (int) (s & m);
170 for (int j = 0; j < n; j++) {
171 s >>>= CHAR_BITS;
172 h[unique][(int) (s & m)]++;
173 }
174 }
175
176
177 final int[] empty = new int[m + 1];
178 for (int i = 0; i <= m; i += 2) {
179 Assertions.assertArrayEquals(empty, h[i], "Even histograms should be empty");
180 }
181
182
183 for (int i = 1; i <= m; i += 2) {
184 Assertions.assertEquals(0, h[i][i]);
185 }
186
187
188 final long[] sum = new long[(m + 1) / 2];
189 for (int i = 1; i <= m; i += 2) {
190 final long total = Arrays.stream(h[i]).sum();
191 Assertions.assertEquals(0, total % n, "Samples should be a multiple of the number of characters");
192 sum[i / 2] = total / n;
193 }
194
195 assertChiSquare(sum, () -> "Unique character distribution");
196
197
198
199
200
201
202 Assertions.assertEquals(0, Long.SIZE % CHAR_BITS, "Character distribution cannot be tested as uniform");
203 for (int i = 1; i <= m; i += 2) {
204 final long[] obs = Arrays.stream(h[i]).filter(c -> c != 0).asLongStream().toArray();
205 final int c = i;
206 assertChiSquare(obs, () -> "Other character distribution for unique character " + c);
207 }
208 }
209
210
211
212
213
214
215
216 private static void assertChiSquare(long[] obs, Supplier<String> msg) {
217 final ChiSquareTest t = new ChiSquareTest();
218 final double alpha = 0.001;
219 final double[] expected = new double[obs.length];
220 Arrays.fill(expected, 1.0 / obs.length);
221 final double p = t.chiSquareTest(expected, obs);
222 Assertions.assertFalse(p < alpha, () -> String.format("%s: chi2 p-value: %s < %s", msg.get(), p, alpha));
223 }
224
225 @ParameterizedTest
226 @ValueSource(longs = {-1, -2, Long.MIN_VALUE})
227 void testGenerateWithSeedInvalidStreamSizeThrows(long size) {
228 final SplittableUniformRandomProvider source = new SequenceGenerator(0);
229 final SeededObjectFactory<Long> factory = (s, r) -> Long.valueOf(s);
230 final IllegalArgumentException ex1 = Assertions.assertThrows(IllegalArgumentException.class,
231 () -> RandomStreams.generateWithSeed(size, source, factory));
232
233 final IllegalArgumentException ex2 = Assertions.assertThrows(IllegalArgumentException.class,
234 () -> source.ints(size));
235 Assertions.assertEquals(ex2.getMessage(), ex1.getMessage(), "Inconsistent exception message");
236 }
237
238 @Test
239 void testGenerateWithSeedNullArgumentThrows() {
240 final long size = 10;
241 final SplittableUniformRandomProvider source = new SequenceGenerator(0);
242 final SeededObjectFactory<Long> factory = (s, r) -> Long.valueOf(s);
243 Assertions.assertThrows(NullPointerException.class,
244 () -> RandomStreams.generateWithSeed(size, null, factory));
245 Assertions.assertThrows(NullPointerException.class,
246 () -> RandomStreams.generateWithSeed(size, source, null));
247 }
248
249
250
251
252
253
254
255
256
257 @ParameterizedTest
258 @CsvSource({
259 "1, 23",
260 "4, 31",
261 "4, 3",
262 "8, 127",
263 })
264 void testGenerateWithSeed(int threads, long streamSize) throws InterruptedException, ExecutionException {
265
266 final SplittableUniformRandomProvider rng = new SplittableUniformRandomProvider() {
267 @Override
268 public long nextLong() {
269 return 1;
270 }
271
272 @Override
273 public SplittableUniformRandomProvider split(UniformRandomProvider source) {
274 return this;
275 }
276 };
277 Assertions.assertEquals(1, RandomStreams.createSeed(rng), "Unexpected seed value");
278
279
280 final SeededObjectFactory<Long> factory = (s, r) -> {
281 Assertions.assertSame(rng, r, "The source RNG is not used");
282 return Long.valueOf(s);
283 };
284
285
286 final ForkJoinPool threadPool = new ForkJoinPool(threads);
287 Long[] values;
288 try {
289 values = threadPool.submit(() ->
290 RandomStreams.generateWithSeed(streamSize, rng, factory).parallel().toArray(Long[]::new)).get();
291 } finally {
292 threadPool.shutdown();
293 }
294
295
296 final long[] actual = Arrays.stream(values).mapToLong(Long::longValue)
297 .map(l -> l - Long.highestOneBit(l)).sorted().toArray();
298 final long[] expected = LongStream.range(0, streamSize).toArray();
299 Assertions.assertArrayEquals(expected, actual);
300 }
301
302 @Test
303 void testGenerateWithSeedSpliteratorThrows() {
304 final long size = 10;
305 final SplittableUniformRandomProvider source = new SequenceGenerator(0);
306 final SeededObjectFactory<Long> factory = (s, r) -> Long.valueOf(s);
307 final Spliterator<Long> s1 = RandomStreams.generateWithSeed(size, source, factory).spliterator();
308 final Consumer<Long> badAction = null;
309 final NullPointerException ex1 = Assertions.assertThrows(NullPointerException.class, () -> s1.tryAdvance(badAction), "tryAdvance");
310 final NullPointerException ex2 = Assertions.assertThrows(NullPointerException.class, () -> s1.forEachRemaining(badAction), "forEachRemaining");
311
312 final Spliterator.OfInt s2 = source.ints().spliterator();
313 final NullPointerException ex3 = Assertions.assertThrows(NullPointerException.class, () -> s2.tryAdvance((IntConsumer) null), "tryAdvance");
314 Assertions.assertEquals(ex3.getMessage(), ex1.getMessage(), "Inconsistent tryAdvance exception message");
315 Assertions.assertEquals(ex3.getMessage(), ex2.getMessage(), "Inconsistent forEachRemaining exception message");
316 }
317
318 @Test
319 void testGenerateWithSeedSpliterator() {
320
321
322 final long initial = RandomStreams.createSeed(new SplittableRandom()::nextLong);
323 final SplittableUniformRandomProvider rng = new SplittableUniformRandomProvider() {
324 @Override
325 public long nextLong() {
326 return initial;
327 }
328
329 @Override
330 public SplittableUniformRandomProvider split(UniformRandomProvider source) {
331 return this;
332 }
333 };
334 Assertions.assertEquals(initial, RandomStreams.createSeed(rng), "Unexpected seed value");
335
336
337 final SeededObjectFactory<Long> factory = (s, r) -> {
338 Assertions.assertSame(rng, r, "The source RNG is not used");
339 return Long.valueOf(s);
340 };
341
342
343
344 final long size = 41;
345 Spliterator<Long> s1 = RandomStreams.generateWithSeed(size, rng, factory).spliterator();
346 Assertions.assertEquals(size, s1.estimateSize());
347 Assertions.assertTrue(s1.hasCharacteristics(Spliterator.SIZED | Spliterator.SUBSIZED | Spliterator.IMMUTABLE),
348 "Invalid characteristics");
349 final Spliterator<Long> s2 = s1.trySplit();
350 final Spliterator<Long> s3 = s1.trySplit();
351 final Spliterator<Long> s4 = s2.trySplit();
352 Assertions.assertEquals(size, s1.estimateSize() + s2.estimateSize() + s3.estimateSize() + s4.estimateSize());
353
354
355 while (s1.estimateSize() > 1) {
356 final long currentSize = s1.estimateSize();
357 final Spliterator<Long> other = s1.trySplit();
358 Assertions.assertEquals(currentSize, s1.estimateSize() + other.estimateSize());
359 s1 = other;
360 }
361 Assertions.assertNull(s1.trySplit(), "Cannot split when size <= 1");
362
363
364
365 final SeedDecoder action = new SeedDecoder(initial);
366
367
368 for (long newSize = s2.estimateSize(); newSize-- > 0;) {
369 Assertions.assertTrue(s2.tryAdvance(action));
370 Assertions.assertEquals(newSize, s2.estimateSize(), "s2 size estimate");
371 }
372 final Consumer<Long> throwIfCalled = r -> Assertions.fail("spliterator should be empty");
373 Assertions.assertFalse(s2.tryAdvance(throwIfCalled));
374 s2.forEachRemaining(throwIfCalled);
375
376
377 action.reset();
378 s3.forEachRemaining(action);
379 Assertions.assertEquals(0, s3.estimateSize());
380 s3.forEachRemaining(throwIfCalled);
381
382
383 final IllegalStateException ex = new IllegalStateException();
384 final Consumer<Long> badAction = r -> {
385 throw ex;
386 };
387 final long currentSize = s4.estimateSize();
388 Assertions.assertTrue(currentSize > 1, "Spliterator requires more elements to test advance");
389 Assertions.assertSame(ex, Assertions.assertThrows(IllegalStateException.class, () -> s4.tryAdvance(badAction)));
390 Assertions.assertEquals(currentSize - 1, s4.estimateSize(), "Spliterator should be advanced even when action throws");
391
392 Assertions.assertSame(ex, Assertions.assertThrows(IllegalStateException.class, () -> s4.forEachRemaining(badAction)));
393 Assertions.assertEquals(0, s4.estimateSize(), "Spliterator should be finished even when action throws");
394 s4.forEachRemaining(throwIfCalled);
395 }
396
397
398
399
400
401
402 @Test
403 void testLargeStreamSize() {
404
405
406 final long initial = RandomStreams.createSeed(new SplittableRandom()::nextLong);
407 final SplittableUniformRandomProvider rng = new SplittableUniformRandomProvider() {
408 @Override
409 public long nextLong() {
410 return initial;
411 }
412
413 @Override
414 public SplittableUniformRandomProvider split(UniformRandomProvider source) {
415 return this;
416 }
417 };
418 Assertions.assertEquals(initial, RandomStreams.createSeed(rng), "Unexpected seed value");
419
420
421 final SeededObjectFactory<Long> factory = (s, r) -> {
422 Assertions.assertSame(rng, r, "The source RNG is not used");
423 return Long.valueOf(s);
424 };
425
426 final Spliterator<Long> s = RandomStreams.generateWithSeed(1L << 62, rng, factory).spliterator();
427
428
429 final Spliterator<Long> s1 = s.trySplit();
430
431
432 final SeedDecoder action = new SeedDecoder(initial);
433 long size = s1.estimateSize();
434 for (int i = 1; i <= 5; i++) {
435 Assertions.assertTrue(s1.tryAdvance(action));
436 Assertions.assertEquals(size - i, s1.estimateSize(), "s1 size estimate");
437 }
438
439
440
441 final long[] expected = {0};
442 s.tryAdvance(seed -> expected[0] = seed);
443 size = s.estimateSize();
444 for (int i = 1; i <= 5; i++) {
445 Assertions.assertTrue(s.tryAdvance(seed -> Assertions.assertEquals(++expected[0], seed)));
446 Assertions.assertEquals(size - i, s.estimateSize(), "s size estimate");
447 }
448 }
449 }