View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.core.util;
18  
19  import java.util.Arrays;
20  import java.util.Spliterator;
21  import java.util.SplittableRandom;
22  import java.util.concurrent.ExecutionException;
23  import java.util.concurrent.ForkJoinPool;
24  import java.util.concurrent.atomic.AtomicLong;
25  import java.util.function.Consumer;
26  import java.util.function.IntConsumer;
27  import java.util.function.LongConsumer;
28  import java.util.function.Supplier;
29  import java.util.stream.LongStream;
30  import org.apache.commons.math3.stat.inference.ChiSquareTest;
31  import org.apache.commons.rng.SplittableUniformRandomProvider;
32  import org.apache.commons.rng.UniformRandomProvider;
33  import org.apache.commons.rng.core.util.RandomStreams.SeededObjectFactory;
34  import org.junit.jupiter.api.Assertions;
35  import org.junit.jupiter.api.Test;
36  import org.junit.jupiter.params.ParameterizedTest;
37  import org.junit.jupiter.params.provider.CsvSource;
38  import org.junit.jupiter.params.provider.ValueSource;
39  
40  /**
41   * Tests for {@link RandomStreams}.
42   */
43  class RandomStreamsTest {
44      /** The size in bits of a single seed character; seeds are shifted and inspected in units of this size. */
45      private static final int CHAR_BITS = 4;
46  
47      /**
48       * Class for outputting a unique sequence from the nextLong() method even under
49       * recursive splitting. Splitting creates a new instance.
50       */
51      private static class SequenceGenerator implements SplittableUniformRandomProvider {
52          /** The value for nextLong. */
53          private final AtomicLong value;
54  
55          /**
56           * @param seed Sequence seed value.
57           */
58          SequenceGenerator(long seed) {
59              value = new AtomicLong(seed);
60          }
61  
62          /**
63           * @param value The value for nextLong.
64           */
65          SequenceGenerator(AtomicLong value) {
66              this.value = value;
67          }
68  
69          @Override
70          public long nextLong() {
71              return value.getAndIncrement();
72          }
73  
74          @Override
75          public SplittableUniformRandomProvider split(UniformRandomProvider source) {
76              // Ignore the source (use of the source is optional)
77              return new SequenceGenerator(value);
78          }
79      }
80  
81      /**
82       * Class for decoding the combined seed ((seed << shift) | position).
83       * Requires the unshifted seed. The shift is assumed to be a multiple of 4.
84       * The first call to the consumer will extract the current position.
85       * Further calls will compare the value with the predicted value using
86       * the last known position.
87       */
88      private static class SeedDecoder implements Consumer<Long>, LongConsumer {
89          /** The initial (unshifted) seed. */
90          private final long initial;
91          /** The current shifted seed. */
92          private long seed;
93          /** The last known position. */
94          private long position = -1;
95  
96          /**
97           * @param initial Unshifted seed value.
98           */
99          SeedDecoder(long initial) {
100             this.initial = initial;
101         }
102 
103         @Override
104         public void accept(long value) {
105             if (position < 0) {
106                 // Search for the initial seed value
107                 seed = initial;
108                 long mask = -1;
109                 while (seed != 0 && (value & mask) != seed) {
110                     seed <<= CHAR_BITS;
111                     mask <<= CHAR_BITS;
112                 }
113                 if (seed == 0) {
114                     Assertions.fail(() -> String.format("Failed to decode position from %s using seed %s",
115                         Long.toBinaryString(value), Long.toBinaryString(initial)));
116                 }
117                 // Remove the seed contribution leaving the position
118                 position = value & ~seed;
119             } else {
120                 // Predict
121                 final long expected = position + 1;
122                 //seed = initial;
123                 while (seed != 0 && Long.compareUnsigned(Long.lowestOneBit(seed), expected) <= 0) {
124                     seed <<= CHAR_BITS;
125                 }
126                 Assertions.assertEquals(expected | seed, value);
127                 position = expected;
128             }
129         }
130 
131         @Override
132         public void accept(Long t) {
133             accept(t.longValue());
134         }
135 
136         /**
137          * Reset the decoder.
138          */
139         void reset() {
140             position = -1;
141         }
142     }
143 
144     /**
145      * Test the seed has the required properties:
146      * <ul>
147      * <li>Test the seed has an odd character in the least significant position
148      * <li>Test the remaining characters in the seed do not match this character
149      * <li>Test the distribution of characters is uniform
150      * <ul>
151      *
152      * <p>The test assumes the character size is 4-bits.
153      *
154      * @param seed the seed
155      */
156     @ParameterizedTest
157     @ValueSource(longs = {1628346812812L})
158     void testCreateSeed(long seed) {
159         final UniformRandomProvider rng = new SplittableRandom(seed)::nextLong;
160 
161         // Histogram the distribution for each unique 4-bit character
162         final int m = (1 << CHAR_BITS) - 1;
163         // Number of remaining characters
164         final int n = (int) Math.ceil((Long.SIZE - CHAR_BITS) / CHAR_BITS);
165         final int[][] h = new int[m + 1][m + 1];
166         final int samples = 1 << 16;
167         for (int i = 0; i < samples; i++) {
168             long s = RandomStreams.createSeed(rng);
169             final int unique = (int) (s & m);
170             for (int j = 0; j < n; j++) {
171                 s >>>= CHAR_BITS;
172                 h[unique][(int) (s & m)]++;
173             }
174         }
175 
176         // Test unique characters are always odd.
177         final int[] empty = new int[m + 1];
178         for (int i = 0; i <= m; i += 2) {
179             Assertions.assertArrayEquals(empty, h[i], "Even histograms should be empty");
180         }
181 
182         // Test unique characters are not repeated
183         for (int i = 1; i <= m; i += 2) {
184             Assertions.assertEquals(0, h[i][i]);
185         }
186 
187         // Chi-square test the distribution of unique characters
188         final long[] sum = new long[(m + 1) / 2];
189         for (int i = 1; i <= m; i += 2) {
190             final long total = Arrays.stream(h[i]).sum();
191             Assertions.assertEquals(0, total % n, "Samples should be a multiple of the number of characters");
192             sum[i / 2] = total / n;
193         }
194 
195         assertChiSquare(sum, () -> "Unique character distribution");
196 
197         // Chi-square test the distribution for each unique character.
198         // Note: This will fail if the characters do not evenly divide into 64.
199         // In that case the expected values are not uniform as the final
200         // character will be truncated and skew the expected values to lower characters.
201         // For simplicity this has not been accounted for as 4-bits evenly divides 64.
202         Assertions.assertEquals(0, Long.SIZE % CHAR_BITS, "Character distribution cannot be tested as uniform");
203         for (int i = 1; i <= m; i += 2) {
204             final long[] obs = Arrays.stream(h[i]).filter(c -> c != 0).asLongStream().toArray();
205             final int c = i;
206             assertChiSquare(obs, () -> "Other character distribution for unique character " + c);
207         }
208     }
209 
210     /**
211      * Assert the observations are uniform using a chi-square test.
212      *
213      * @param obs Observations.
214      * @param msg Failure message prefix.
215      */
216     private static void assertChiSquare(long[] obs, Supplier<String> msg) {
217         final ChiSquareTest t = new ChiSquareTest();
218         final double alpha = 0.001;
219         final double[] expected = new double[obs.length];
220         Arrays.fill(expected, 1.0 / obs.length);
221         final double p = t.chiSquareTest(expected, obs);
222         Assertions.assertFalse(p < alpha, () -> String.format("%s: chi2 p-value: %s < %s", msg.get(), p, alpha));
223     }
224 
225     @ParameterizedTest
226     @ValueSource(longs = {-1, -2, Long.MIN_VALUE})
227     void testGenerateWithSeedInvalidStreamSizeThrows(long size) {
228         final SplittableUniformRandomProvider source = new SequenceGenerator(0);
229         final SeededObjectFactory<Long> factory = (s, r) -> Long.valueOf(s);
230         final IllegalArgumentException ex1 = Assertions.assertThrows(IllegalArgumentException.class,
231             () -> RandomStreams.generateWithSeed(size, source, factory));
232         // Check the exception method is consistent with UniformRandomProvider stream methods
233         final IllegalArgumentException ex2 = Assertions.assertThrows(IllegalArgumentException.class,
234             () -> source.ints(size));
235         Assertions.assertEquals(ex2.getMessage(), ex1.getMessage(), "Inconsistent exception message");
236     }
237 
238     @Test
239     void testGenerateWithSeedNullArgumentThrows() {
240         final long size = 10;
241         final SplittableUniformRandomProvider source = new SequenceGenerator(0);
242         final SeededObjectFactory<Long> factory = (s, r) -> Long.valueOf(s);
243         Assertions.assertThrows(NullPointerException.class,
244             () -> RandomStreams.generateWithSeed(size, null, factory));
245         Assertions.assertThrows(NullPointerException.class,
246             () -> RandomStreams.generateWithSeed(size, source, null));
247     }
248 
249     /**
250      * Test that the seed passed to the factory is ((seed << shift) | position).
251      * This is done by creating an initial seed value of 1. When removed the
252      * remaining values should be a sequence.
253      *
254      * @param threads Number of threads.
255      * @param streamSize Stream size.
256      */
257     @ParameterizedTest
258     @CsvSource({
259         "1, 23",
260         "4, 31",
261         "4, 3",
262         "8, 127",
263     })
264     void testGenerateWithSeed(int threads, long streamSize) throws InterruptedException, ExecutionException {
265         // Provide a generator that results in the seed being set as 1.
266         final SplittableUniformRandomProvider rng = new SplittableUniformRandomProvider() {
267             @Override
268             public long nextLong() {
269                 return 1;
270             }
271 
272             @Override
273             public SplittableUniformRandomProvider split(UniformRandomProvider source) {
274                 return this;
275             }
276         };
277         Assertions.assertEquals(1, RandomStreams.createSeed(rng), "Unexpected seed value");
278 
279         // Create a factory that will return the seed passed to the factory
280         final SeededObjectFactory<Long> factory = (s, r) -> {
281             Assertions.assertSame(rng, r, "The source RNG is not used");
282             return Long.valueOf(s);
283         };
284 
285         // Stream in a custom pool
286         final ForkJoinPool threadPool = new ForkJoinPool(threads);
287         Long[] values;
288         try {
289             values = threadPool.submit(() ->
290                 RandomStreams.generateWithSeed(streamSize, rng, factory).parallel().toArray(Long[]::new)).get();
291         } finally {
292             threadPool.shutdown();
293         }
294 
295         // Remove the highest 1 bit from each long. The rest should be a sequence.
296         final long[] actual = Arrays.stream(values).mapToLong(Long::longValue)
297                 .map(l -> l - Long.highestOneBit(l)).sorted().toArray();
298         final long[] expected = LongStream.range(0, streamSize).toArray();
299         Assertions.assertArrayEquals(expected, actual);
300     }
301 
302     @Test
303     void testGenerateWithSeedSpliteratorThrows() {
304         final long size = 10;
305         final SplittableUniformRandomProvider source = new SequenceGenerator(0);
306         final SeededObjectFactory<Long> factory = (s, r) -> Long.valueOf(s);
307         final Spliterator<Long> s1 = RandomStreams.generateWithSeed(size, source, factory).spliterator();
308         final Consumer<Long> badAction = null;
309         final NullPointerException ex1 = Assertions.assertThrows(NullPointerException.class, () -> s1.tryAdvance(badAction), "tryAdvance");
310         final NullPointerException ex2 = Assertions.assertThrows(NullPointerException.class, () -> s1.forEachRemaining(badAction), "forEachRemaining");
311         // Check the exception method is consistent with UniformRandomProvider stream methods
312         final Spliterator.OfInt s2 = source.ints().spliterator();
313         final NullPointerException ex3 = Assertions.assertThrows(NullPointerException.class, () -> s2.tryAdvance((IntConsumer) null), "tryAdvance");
314         Assertions.assertEquals(ex3.getMessage(), ex1.getMessage(), "Inconsistent tryAdvance exception message");
315         Assertions.assertEquals(ex3.getMessage(), ex2.getMessage(), "Inconsistent forEachRemaining exception message");
316     }
317 
318     @Test
319     void testGenerateWithSeedSpliterator() {
320         // Create an initial seed value. This should not be modified by the algorithm
321         // when generating a 'new' seed from the RNG.
322         final long initial = RandomStreams.createSeed(new SplittableRandom()::nextLong);
323         final SplittableUniformRandomProvider rng = new SplittableUniformRandomProvider() {
324             @Override
325             public long nextLong() {
326                 return initial;
327             }
328 
329             @Override
330             public SplittableUniformRandomProvider split(UniformRandomProvider source) {
331                 return this;
332             }
333         };
334         Assertions.assertEquals(initial, RandomStreams.createSeed(rng), "Unexpected seed value");
335 
336         // Create a factory that will return the seed passed to the factory
337         final SeededObjectFactory<Long> factory = (s, r) -> {
338             Assertions.assertSame(rng, r, "The source RNG is not used");
339             return Long.valueOf(s);
340         };
341 
342         // Split a large spliterator into four smaller ones;
343         // each is used to test different functionality
344         final long size = 41;
345         Spliterator<Long> s1 = RandomStreams.generateWithSeed(size, rng, factory).spliterator();
346         Assertions.assertEquals(size, s1.estimateSize());
347         Assertions.assertTrue(s1.hasCharacteristics(Spliterator.SIZED | Spliterator.SUBSIZED | Spliterator.IMMUTABLE),
348             "Invalid characteristics");
349         final Spliterator<Long> s2 = s1.trySplit();
350         final Spliterator<Long> s3 = s1.trySplit();
351         final Spliterator<Long> s4 = s2.trySplit();
352         Assertions.assertEquals(size, s1.estimateSize() + s2.estimateSize() + s3.estimateSize() + s4.estimateSize());
353 
354         // s1. Test cannot split indefinitely
355         while (s1.estimateSize() > 1) {
356             final long currentSize = s1.estimateSize();
357             final Spliterator<Long> other = s1.trySplit();
358             Assertions.assertEquals(currentSize, s1.estimateSize() + other.estimateSize());
359             s1 = other;
360         }
361         Assertions.assertNull(s1.trySplit(), "Cannot split when size <= 1");
362 
363         // Create an action that will decode the shift and position using the
364         // known initial seed. This can be used to predict and assert the next value.
365         final SeedDecoder action = new SeedDecoder(initial);
366 
367         // s2. Test advance
368         for (long newSize = s2.estimateSize(); newSize-- > 0;) {
369             Assertions.assertTrue(s2.tryAdvance(action));
370             Assertions.assertEquals(newSize, s2.estimateSize(), "s2 size estimate");
371         }
372         final Consumer<Long> throwIfCalled = r -> Assertions.fail("spliterator should be empty");
373         Assertions.assertFalse(s2.tryAdvance(throwIfCalled));
374         s2.forEachRemaining(throwIfCalled);
375 
376         // s3. Test forEachRemaining
377         action.reset();
378         s3.forEachRemaining(action);
379         Assertions.assertEquals(0, s3.estimateSize());
380         s3.forEachRemaining(throwIfCalled);
381 
382         // s4. Test tryAdvance and forEachRemaining when the action throws an exception
383         final IllegalStateException ex = new IllegalStateException();
384         final Consumer<Long> badAction = r -> {
385             throw ex;
386         };
387         final long currentSize = s4.estimateSize();
388         Assertions.assertTrue(currentSize > 1, "Spliterator requires more elements to test advance");
389         Assertions.assertSame(ex, Assertions.assertThrows(IllegalStateException.class, () -> s4.tryAdvance(badAction)));
390         Assertions.assertEquals(currentSize - 1, s4.estimateSize(), "Spliterator should be advanced even when action throws");
391 
392         Assertions.assertSame(ex, Assertions.assertThrows(IllegalStateException.class, () -> s4.forEachRemaining(badAction)));
393         Assertions.assertEquals(0, s4.estimateSize(), "Spliterator should be finished even when action throws");
394         s4.forEachRemaining(throwIfCalled);
395     }
396 
397     /**
398      * Test a very large stream size above 2<sup>60</sup>.
399      * In this case it is not possible to prepend a 4-bit character
400      * to the stream position. The seed passed to the factory will be the stream position.
401      */
402     @Test
403     void testLargeStreamSize() {
404         // Create an initial seed value. This should not be modified by the algorithm
405         // when generating a 'new' seed from the RNG.
406         final long initial = RandomStreams.createSeed(new SplittableRandom()::nextLong);
407         final SplittableUniformRandomProvider rng = new SplittableUniformRandomProvider() {
408             @Override
409             public long nextLong() {
410                 return initial;
411             }
412 
413             @Override
414             public SplittableUniformRandomProvider split(UniformRandomProvider source) {
415                 return this;
416             }
417         };
418         Assertions.assertEquals(initial, RandomStreams.createSeed(rng), "Unexpected seed value");
419 
420         // Create a factory that will return the seed passed to the factory
421         final SeededObjectFactory<Long> factory = (s, r) -> {
422             Assertions.assertSame(rng, r, "The source RNG is not used");
423             return Long.valueOf(s);
424         };
425 
426         final Spliterator<Long> s = RandomStreams.generateWithSeed(1L << 62, rng, factory).spliterator();
427 
428         // Split uses a divide-by-two approach. The child uses the smaller half.
429         final Spliterator<Long> s1 = s.trySplit();
430 
431         // Lower half. The next position can be predicted using the decoder.
432         final SeedDecoder action = new SeedDecoder(initial);
433         long size = s1.estimateSize();
434         for (int i = 1; i <= 5; i++) {
435             Assertions.assertTrue(s1.tryAdvance(action));
436             Assertions.assertEquals(size - i, s1.estimateSize(), "s1 size estimate");
437         }
438 
439         // Upper half. This should be just the stream position which we can
440         // collect with a call to advance.
441         final long[] expected = {0};
442         s.tryAdvance(seed -> expected[0] = seed);
443         size = s.estimateSize();
444         for (int i = 1; i <= 5; i++) {
445             Assertions.assertTrue(s.tryAdvance(seed -> Assertions.assertEquals(++expected[0], seed)));
446             Assertions.assertEquals(size - i, s.estimateSize(), "s size estimate");
447         }
448     }
449 }