/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.rng.sampling.distribution;

import java.util.Arrays;
import java.util.Locale;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Stream;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.RandomAssert;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;

33  /**
34   * Test for the {@link FastLoadedDiceRollerDiscreteSampler}.
35   */
36  class FastLoadedDiceRollerDiscreteSamplerTest {
37      /**
38       * Creates the sampler.
39       *
40       * @param frequencies Observed frequencies.
41       * @return the FLDR sampler
42       */
43      private static SharedStateDiscreteSampler createSampler(long... frequencies) {
44          final UniformRandomProvider rng = RandomAssert.createRNG();
45          return FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies);
46      }
47  
48      /**
49       * Creates the sampler.
50       *
51       * @param weights Weights.
52       * @return the FLDR sampler
53       */
54      private static SharedStateDiscreteSampler createSampler(double... weights) {
55          final UniformRandomProvider rng = RandomAssert.createRNG();
56          return FastLoadedDiceRollerDiscreteSampler.of(rng, weights);
57      }
58  
59      /**
60       * Return a stream of invalid frequencies for a discrete distribution.
61       *
62       * @return the stream of invalid frequencies
63       */
64      static Stream<long[]> testFactoryConstructorFrequencies() {
65          return Stream.of(
66              // Null or empty
67              (long[]) null,
68              new long[0],
69              // Negative
70              new long[] {-1, 2, 3},
71              new long[] {1, -2, 3},
72              new long[] {1, 2, -3},
73              // Overflow of sum
74              new long[] {Long.MAX_VALUE, Long.MAX_VALUE},
75              // x+x+2 == 0
76              new long[] {Long.MAX_VALUE, Long.MAX_VALUE, 2},
77              // x+x+x == x - 2 (i.e. positive)
78              new long[] {Long.MAX_VALUE, Long.MAX_VALUE, Long.MAX_VALUE},
79              // Zero sum
80              new long[1],
81              new long[4]
82          );
83      }
84  
85      @ParameterizedTest
86      @MethodSource
87      void testFactoryConstructorFrequencies(long[] frequencies) {
88          Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(frequencies));
89      }
90  
91      /**
92       * Return a stream of invalid weights for a discrete distribution.
93       *
94       * @return the stream of invalid weights
95       */
96      static Stream<double[]> testFactoryConstructorWeights() {
97          return Stream.of(
98              // Null or empty
99              (double[]) null,
100             new double[0],
101             // Negative, infinite or NaN
102             new double[] {-1, 2, 3},
103             new double[] {1, -2, 3},
104             new double[] {1, 2, -3},
105             new double[] {Double.POSITIVE_INFINITY, 2, 3},
106             new double[] {1, Double.POSITIVE_INFINITY, 3},
107             new double[] {1, 2, Double.POSITIVE_INFINITY},
108             new double[] {Double.NaN, 2, 3},
109             new double[] {1, Double.NaN, 3},
110             new double[] {1, 2, Double.NaN},
111             // Zero sum
112             new double[1],
113             new double[4]
114         );
115     }
116 
117     @ParameterizedTest
118     @MethodSource
119     void testFactoryConstructorWeights(double[] weights) {
120         Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(weights));
121     }
122 
123     @Test
124     void testToString() {
125         for (final long[] observed : new long[][] {{42}, {1, 2, 3}}) {
126             final SharedStateDiscreteSampler sampler = createSampler(observed);
127             Assertions.assertTrue(sampler.toString().toLowerCase().contains("fast loaded dice roller"));
128         }
129     }
130 
131     @Test
132     void testSingleCategory() {
133         final int n = 13;
134         final int[] expected = new int[n];
135         Assertions.assertArrayEquals(expected, createSampler(42).samples(n).toArray());
136         Assertions.assertArrayEquals(expected, createSampler(0.55).samples(n).toArray());
137     }
138 
139     @Test
140     void testSingleFrequency() {
141         final long[] frequencies = new long[5];
142         final int category = 2;
143         frequencies[category] = 1;
144         final SharedStateDiscreteSampler sampler = createSampler(frequencies);
145         final int n = 7;
146         final int[] expected = new int[n];
147         Arrays.fill(expected, category);
148         Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
149     }
150 
151     @Test
152     void testSingleWeight() {
153         final double[] weights = new double[5];
154         final int category = 3;
155         weights[category] = 1.5;
156         final SharedStateDiscreteSampler sampler = createSampler(weights);
157         final int n = 6;
158         final int[] expected = new int[n];
159         Arrays.fill(expected, category);
160         Assertions.assertArrayEquals(expected, sampler.samples(n).toArray());
161     }
162 
163     @Test
164     void testIndexOfNonZero() {
165         Assertions.assertThrows(IllegalStateException.class,
166             () -> FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(new long[3]));
167         final long[] data = new long[3];
168         for (int i = 0; i < data.length; i++) {
169             data[i] = 13;
170             Assertions.assertEquals(i, FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(data));
171             data[i] = 0;
172         }
173     }
174 
175     @ParameterizedTest
176     @ValueSource(longs = {0, 1, -1, Integer.MAX_VALUE, 1L << 34})
177     void testCheckArraySize(long size) {
178         // This is the same value as the sampler
179         final int max = Integer.MAX_VALUE - 8;
180         // Note: The method does not test for negatives.
181         // This is not required when validating a positive int multiplied by another positive int.
182         if (size > max) {
183             Assertions.assertThrows(IllegalArgumentException.class,
184                 () -> FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
185         } else {
186             Assertions.assertEquals((int) size, FastLoadedDiceRollerDiscreteSampler.checkArraySize(size));
187         }
188     }
189 
190     /**
191      * Return a stream of expected frequencies for a discrete distribution.
192      *
193      * @return the stream of expected frequencies
194      */
195     static Stream<long[]> testSamplesFrequencies() {
196         return Stream.of(
197             // Single category
198             new long[] {0, 0, 42, 0, 0},
199             // Sum to a power of 2
200             new long[] {1, 1, 2, 3, 1},
201             new long[] {0, 1, 1, 0, 2, 3, 1, 0},
202             // Do not sum to a power of 2
203             new long[] {1, 2, 3, 1, 3},
204             new long[] {1, 0, 2, 0, 3, 1, 3},
205             // Large frequencies
206             new long[] {5126734627834L, 213267384684832L, 126781236718L, 71289979621378L}
207         );
208     }
209 
210     /**
211      * Check the distribution of samples match the expected probabilities.
212      *
213      * @param expectedFrequencies Expected frequencies.
214      */
215     @ParameterizedTest
216     @MethodSource
217     void testSamplesFrequencies(long[] expectedFrequencies) {
218         final SharedStateDiscreteSampler sampler = createSampler(expectedFrequencies);
219         final int numberOfSamples = 10000;
220         final long[] samples = new long[expectedFrequencies.length];
221         sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
222 
223         // Handle a test with some zero-probability observations by mapping them out
224         int mapSize = 0;
225         double sum = 0;
226         for (final double f : expectedFrequencies) {
227             if (f != 0) {
228                 mapSize++;
229                 sum += f;
230             }
231         }
232 
233         // Single category will break the Chi-square test
234         if (mapSize == 1) {
235             int index = 0;
236             while (index < expectedFrequencies.length) {
237                 if (expectedFrequencies[index] != 0) {
238                     break;
239                 }
240                 index++;
241             }
242             Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples");
243             return;
244         }
245 
246         final double[] expected = new double[mapSize];
247         final long[] observed = new long[mapSize];
248         for (int i = 0; i < expectedFrequencies.length; i++) {
249             if (expectedFrequencies[i] != 0) {
250                 --mapSize;
251                 expected[mapSize] = expectedFrequencies[i] / sum;
252                 observed[mapSize] = samples[i];
253             } else {
254                 Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
255             }
256         }
257 
258         final ChiSquareTest chiSquareTest = new ChiSquareTest();
259         // Pass if we cannot reject null hypothesis that the distributions are the same.
260         Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
261     }
262 
263     /**
264      * Return a stream of expected weights for a discrete distribution.
265      *
266      * @return the stream of expected weights
267      */
268     static Stream<double[]> testSamplesWeights() {
269         return Stream.of(
270             // Single category
271             new double[] {0, 0, 0.523, 0, 0},
272             // Sum to a power of 2
273             new double[] {0.125, 0.125, 0.25, 0.375, 0.125},
274             new double[] {0, 0.125, 0.125, 0.25, 0, 0.375, 0.125, 0},
275             // Do not sum to a power of 2
276             new double[] {0.1, 0.2, 0.3, 0.1, 0.3},
277             new double[] {0.1, 0, 0.2, 0, 0.3, 0.1, 0.3},
278             // Sub-normal numbers
279             new double[] {5 * Double.MIN_NORMAL, 2 * Double.MIN_NORMAL, 3 * Double.MIN_NORMAL, 9 * Double.MIN_NORMAL},
280             new double[] {2 * Double.MIN_NORMAL, Double.MIN_NORMAL, 0.5 * Double.MIN_NORMAL, 0.75 * Double.MIN_NORMAL},
281             new double[] {Double.MIN_VALUE, 2 * Double.MIN_VALUE, 3 * Double.MIN_VALUE, 7 * Double.MIN_VALUE},
282             // Large range of magnitude
283             new double[] {1.0, 2.0, Math.scalb(3.0, -32), Math.scalb(4.0, -65), 5.0},
284             new double[] {Math.scalb(1.0, 35), Math.scalb(2.0, 35), Math.scalb(3.0, -32), Math.scalb(4.0, -65), Math.scalb(5.0, 35)},
285             // Sum to infinite
286             new double[] {Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE / 2, Double.MAX_VALUE / 4}
287         );
288     }
289 
290     /**
291      * Check the distribution of samples match the expected weights.
292      *
293      * @param weights Category weights.
294      */
295     @ParameterizedTest
296     @MethodSource
297     void testSamplesWeights(double[] weights) {
298         final SharedStateDiscreteSampler sampler = createSampler(weights);
299         final int numberOfSamples = 10000;
300         final long[] samples = new long[weights.length];
301         sampler.samples(numberOfSamples).forEach(x -> samples[x]++);
302 
303         // Handle a test with some zero-probability observations by mapping them out
304         int mapSize = 0;
305         double sum = 0;
306         // Handle infinite sum using a rolling mean for normalisation
307         final Mean mean = new Mean();
308         for (final double w : weights) {
309             if (w != 0) {
310                 mapSize++;
311                 sum += w;
312                 mean.increment(w);
313             }
314         }
315 
316         // Single category will break the Chi-square test
317         if (mapSize == 1) {
318             int index = 0;
319             while (index < weights.length) {
320                 if (weights[index] != 0) {
321                     break;
322                 }
323                 index++;
324             }
325             Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples");
326             return;
327         }
328 
329         final double mu = mean.getResult();
330         final int n = mapSize;
331         final double s = sum;
332         final DoubleUnaryOperator normalise = Double.isInfinite(sum) ?
333             x -> (x / mu) * n :
334             x -> x / s;
335 
336         final double[] expected = new double[mapSize];
337         final long[] observed = new long[mapSize];
338         for (int i = 0; i < weights.length; i++) {
339             if (weights[i] != 0) {
340                 --mapSize;
341                 expected[mapSize] = normalise.applyAsDouble(weights[i]);
342                 observed[mapSize] = samples[i];
343             } else {
344                 Assertions.assertEquals(0, samples[i], "No samples expected from zero probability");
345             }
346         }
347 
348         final ChiSquareTest chiSquareTest = new ChiSquareTest();
349         // Pass if we cannot reject null hypothesis that the distributions are the same.
350         Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
351     }
352 
353     /**
354      * Return a stream of expected frequencies for a discrete distribution where the frequencies
355      * can be converted to a {@code double} without loss of precision.
356      *
357      * @return the stream of expected frequencies
358      */
359     static Stream<long[]> testSamplesWeightsMatchesFrequencies() {
360         // Reuse the same frequencies.
361         // Those that cannot be converted to a double are ignored by the test.
362         return testSamplesFrequencies();
363     }
364 
365     /**
366      * Check the distribution of samples when the frequencies can be converted to weights without
367      * loss of precision.
368      *
369      * @param frequencies Expected frequencies.
370      */
371     @ParameterizedTest
372     @MethodSource
373     void testSamplesWeightsMatchesFrequencies(long[] frequencies) {
374         final double[] weights = new double[frequencies.length];
375         for (int i = 0; i < frequencies.length; i++) {
376             final double w = frequencies[i];
377             Assumptions.assumeTrue((long) w == frequencies[i]);
378             // Ensure the exponent is set in the event of simple frequencies
379             weights[i] = Math.scalb(w, -35);
380         }
381         final UniformRandomProvider[] rngs = RandomAssert.createRNG(2);
382         final UniformRandomProvider rng1 = rngs[0];
383         final UniformRandomProvider rng2 = rngs[1];
384         final SharedStateDiscreteSampler sampler1 =
385             FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
386         final SharedStateDiscreteSampler sampler2 =
387             FastLoadedDiceRollerDiscreteSampler.of(rng2, weights);
388         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
389     }
390 
391     /**
392      * Test scaled weights. The sampler uses the relative magnitude of weights and the
393      * output should be invariant to scaling. The weights are sampled from the 2^53 dyadic
394      * rationals in [0, 1). A scale factor of -1021 is the lower limit if a weight is
395      * 2^-53 to maintain a non-zero weight. The upper limit is 1023 if a weight is 1 to avoid
396      * infinite values. Note that it does not matter if the sum of weights is infinite; only
397      * the individual weights must be finite.
398      *
399      * @param scaleFactor the scale factor
400      */
401     @ParameterizedTest
402     @ValueSource(ints = {1023, 67, 1, -59, -1020, -1021})
403     void testScaledWeights(int scaleFactor) {
404         // Weights in [0, 1)
405         final double[] w1 = RandomAssert.createRNG().doubles(10).toArray();
406         final double scale = Math.scalb(1.0, scaleFactor);
407         final double[] w2 = Arrays.stream(w1).map(x -> x * scale).toArray();
408         final UniformRandomProvider[] rngs = RandomAssert.createRNG(2);
409         final UniformRandomProvider rng1 = rngs[0];
410         final UniformRandomProvider rng2 = rngs[1];
411         final SharedStateDiscreteSampler sampler1 =
412             FastLoadedDiceRollerDiscreteSampler.of(rng1, w1);
413         final SharedStateDiscreteSampler sampler2 =
414             FastLoadedDiceRollerDiscreteSampler.of(rng2, w2);
415         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
416     }
417 
418     /**
419      * Test the alpha parameter removes small relative weights.
420      * Weights should be removed if they are {@code 2^alpha} smaller than the largest
421      * weight.
422      *
423      * @param alpha Alpha parameter
424      */
425     @ParameterizedTest
426     @ValueSource(ints = {13, 30, 53})
427     void testAlphaRemovesWeights(int alpha) {
428         // The small weight must be > 2^alpha smaller so scale by (alpha + 1)
429         final double small = Math.scalb(1.0, -(alpha + 1));
430         final double[] w1 = {1, 0.5, 0.5, 0};
431         final double[] w2 = {1, 0.5, 0.5, small};
432         final UniformRandomProvider[] rngs = RandomAssert.createRNG(3);
433         final UniformRandomProvider rng1 = rngs[0];
434         final UniformRandomProvider rng2 = rngs[1];
435         final UniformRandomProvider rng3 = rngs[2];
436 
437         final int n = 10;
438         final int[] s1 = FastLoadedDiceRollerDiscreteSampler.of(rng1, w1).samples(n).toArray();
439         final int[] s2 = FastLoadedDiceRollerDiscreteSampler.of(rng2, w2, alpha).samples(n).toArray();
440         final int[] s3 = FastLoadedDiceRollerDiscreteSampler.of(rng3, w2, alpha + 1).samples(n).toArray();
441 
442         Assertions.assertArrayEquals(s1, s2, "alpha parameter should ignore the small weight");
443         Assertions.assertFalse(Arrays.equals(s1, s3), "alpha+1 parameter should not ignore the small weight");
444     }
445 
446     static Stream<long[]> testSharedStateSampler() {
447         return Stream.of(
448             new long[] {42},
449             new long[] {1, 1, 2, 3, 1}
450         );
451     }
452 
453     @ParameterizedTest
454     @MethodSource
455     void testSharedStateSampler(long[] frequencies) {
456         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
457         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
458         final SharedStateDiscreteSampler sampler1 =
459             FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies);
460         final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
461         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
462     }
463 }