/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.ArrayList;
20  import java.util.Collection;
21  import java.util.HashSet;
22  import java.util.LinkedList;
23  import java.util.List;
24  import java.util.ListIterator;
25  import java.util.Set;
26  import java.util.function.Supplier;
27  import org.apache.commons.math3.stat.inference.ChiSquareTest;
28  import org.apache.commons.rng.UniformRandomProvider;
29  import org.junit.jupiter.api.Assertions;
30  import org.junit.jupiter.api.Test;
31  
32  /**
33   * Tests for {@link ListSampler}.
34   */
35  class ListSamplerTest {
36      @Test
37      void testSample() {
38          final String[][] c = {{"0", "1"}, {"0", "2"}, {"0", "3"}, {"0", "4"},
39                                {"1", "2"}, {"1", "3"}, {"1", "4"},
40                                {"2", "3"}, {"2", "4"},
41                                {"3", "4"}};
42          final long[] observed = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
43          final double[] expected = {100, 100, 100, 100, 100, 100, 100, 100, 100, 100};
44  
45          final HashSet<String> cPop = new HashSet<>(); // {0, 1, 2, 3, 4}.
46          for (int i = 0; i < 5; i++) {
47              cPop.add(Integer.toString(i));
48          }
49  
50          final List<Set<String>> sets = new ArrayList<>(); // 2-sets from 5.
51          for (int i = 0; i < 10; i++) {
52              final HashSet<String> hs = new HashSet<>();
53              hs.add(c[i][0]);
54              hs.add(c[i][1]);
55              sets.add(hs);
56          }
57  
58          final UniformRandomProvider rng = RandomAssert.createRNG();
59          for (int i = 0; i < 1000; i++) {
60              observed[findSample(sets, ListSampler.sample(rng, new ArrayList<>(cPop), 2))]++;
61          }
62  
63          // Pass if we cannot reject null hypothesis that distributions are the same.
64          final ChiSquareTest chiSquareTest = new ChiSquareTest();
65          Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
66      }
67  
68      @Test
69      void testSampleWhole() {
70          // Sample of size = size of collection must return the same collection.
71          final List<String> list = new ArrayList<>();
72          list.add("one");
73  
74          final UniformRandomProvider rng = RandomAssert.seededRNG();
75          final List<String> one = ListSampler.sample(rng, list, 1);
76          Assertions.assertEquals(1, one.size());
77          Assertions.assertTrue(one.contains("one"));
78      }
79  
80      @Test
81      void testSamplePrecondition1() {
82          // Must fail for sample size > collection size.
83          final List<String> list = new ArrayList<>();
84          list.add("one");
85          final UniformRandomProvider rng = RandomAssert.seededRNG();
86          Assertions.assertThrows(IllegalArgumentException.class,
87              () -> ListSampler.sample(rng, list, 2));
88      }
89  
90      @Test
91      void testSamplePrecondition2() {
92          // Must fail for empty collection.
93          final List<String> list = new ArrayList<>();
94          final UniformRandomProvider rng = RandomAssert.seededRNG();
95          Assertions.assertThrows(IllegalArgumentException.class,
96              () -> ListSampler.sample(rng, list, 1));
97      }
98  
99      @Test
100     void testShuffle() {
101         final UniformRandomProvider rng = RandomAssert.createRNG();
102         final List<Integer> orig = new ArrayList<>();
103         for (int i = 0; i < 10; i++) {
104             orig.add((i + 1) * rng.nextInt());
105         }
106 
107         final List<Integer> arrayList = new ArrayList<>(orig);
108 
109         ListSampler.shuffle(rng, arrayList);
110         // Ensure that at least one entry has moved.
111         Assertions.assertTrue(compare(orig, arrayList, 0, orig.size(), false), "ArrayList");
112 
113         final List<Integer> linkedList = new LinkedList<>(orig);
114 
115         ListSampler.shuffle(rng, linkedList);
116         // Ensure that at least one entry has moved.
117         Assertions.assertTrue(compare(orig, linkedList, 0, orig.size(), false), "LinkedList");
118     }
119 
120     @Test
121     void testShuffleTail() {
122         final UniformRandomProvider rng = RandomAssert.createRNG();
123         final List<Integer> orig = new ArrayList<>();
124         for (int i = 0; i < 10; i++) {
125             orig.add((i + 1) * rng.nextInt());
126         }
127         final List<Integer> list = new ArrayList<>(orig);
128 
129         final int start = 4;
130         ListSampler.shuffle(rng, list, start, false);
131 
132         // Ensure that all entries below index "start" did not move.
133         Assertions.assertTrue(compare(orig, list, 0, start, true));
134 
135         // Ensure that at least one entry has moved.
136         Assertions.assertTrue(compare(orig, list, start, orig.size(), false));
137     }
138 
139     @Test
140     void testShuffleHead() {
141         final UniformRandomProvider rng = RandomAssert.createRNG();
142         final List<Integer> orig = new ArrayList<>();
143         for (int i = 0; i < 10; i++) {
144             orig.add((i + 1) * rng.nextInt());
145         }
146         final List<Integer> list = new ArrayList<>(orig);
147 
148         final int start = 4;
149         ListSampler.shuffle(rng, list, start, true);
150 
151         // Ensure that all entries above index "start" did not move.
152         Assertions.assertTrue(compare(orig, list, start + 1, orig.size(), true));
153 
154         // Ensure that at least one entry has moved.
155         Assertions.assertTrue(compare(orig, list, 0, start + 1, false));
156     }
157 
158     /**
159      * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
160      * The implementation may be different but the result is a Fisher-Yates shuffle so the
161      * output order should match.
162      */
163     @Test
164     void testShuffleMatchesPermutationSamplerShuffle() {
165         final UniformRandomProvider rng = RandomAssert.seededRNG();
166         final List<Integer> orig = new ArrayList<>();
167         for (int i = 0; i < 10; i++) {
168             orig.add((i + 1) * rng.nextInt());
169         }
170 
171         assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<>(orig));
172         assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<>(orig));
173     }
174 
175     /**
176      * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
177      * The implementation may be different but the result is a Fisher-Yates shuffle so the
178      * output order should match.
179      */
180     @Test
181     void testShuffleMatchesPermutationSamplerShuffleDirectional() {
182         final UniformRandomProvider rng = RandomAssert.seededRNG();
183         final List<Integer> orig = new ArrayList<>();
184         for (int i = 0; i < 10; i++) {
185             orig.add((i + 1) * rng.nextInt());
186         }
187 
188         assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<>(orig), 4, true);
189         assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<>(orig), 4, false);
190         assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<>(orig), 4, true);
191         assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<>(orig), 4, false);
192     }
193 
194     /**
195      * This test hits the edge case when a LinkedList is small enough that the algorithm
196      * using a RandomAccess list is faster than the one with an iterator.
197      */
198     @Test
199     void testShuffleWithSmallLinkedList() {
200         final UniformRandomProvider rng = RandomAssert.seededRNG();
201         final int size = 3;
202         final List<Integer> orig = new ArrayList<>();
203         for (int i = 0; i < size; i++) {
204             orig.add((i + 1) * rng.nextInt());
205         }
206 
207         // When the size is small there is a chance that the list has no entries that move.
208         // E.g. The number of permutations of 3 items is only 6 giving a 1/6 chance of no change.
209         // So repeat test that the small shuffle matches the PermutationSampler.
210         // 10 times is (1/6)^10 or 1 in 60,466,176 of no change.
211         for (int i = 0; i < 10; i++) {
212             assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<>(orig), size - 1, true);
213         }
214     }
215 
216     //// Support methods.
217 
218     /**
219      * If {@code same == true}, return {@code true} if all entries are
220      * the same; if {@code same == false}, return {@code true} if at
221      * least one entry is different.
222      */
223     private static <T> boolean compare(List<T> orig,
224                                        List<T> list,
225                                        int start,
226                                        int end,
227                                        boolean same) {
228         for (int i = start; i < end; i++) {
229             if (!orig.get(i).equals(list.get(i))) {
230                 return !same;
231             }
232         }
233         return same;
234     }
235 
236     private static <T extends Set<String>> int findSample(List<T> u,
237                                                           Collection<String> sampList) {
238         final String[] samp = sampList.toArray(new String[sampList.size()]);
239         for (int i = 0; i < u.size(); i++) {
240             final T set = u.get(i);
241             final HashSet<String> sampSet = new HashSet<>();
242             for (int j = 0; j < samp.length; j++) {
243                 sampSet.add(samp[j]);
244             }
245             if (set.equals(sampSet)) {
246                 return i;
247             }
248         }
249 
250         Assertions.fail("Sample not found: { " + samp[0] + ", " + samp[1] + " }");
251         return -1;
252     }
253 
254     /**
255      * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
256      *
257      * @param list Array whose entries will be shuffled (in-place).
258      */
259     private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list) {
260         final int[] array = new int[list.size()];
261         ListIterator<Integer> it = list.listIterator();
262         for (int i = 0; i < array.length; i++) {
263             array[i] = it.next();
264         }
265 
266         // Identical RNGs
267         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
268         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
269 
270         ListSampler.shuffle(rng1, list);
271         PermutationSampler.shuffle(rng2, array);
272 
273         final Supplier<String> msg = () -> "Type=" + list.getClass().getSimpleName();
274         it = list.listIterator();
275         for (int i = 0; i < array.length; i++) {
276             Assertions.assertEquals(array[i], it.next().intValue(), msg);
277         }
278     }
279     /**
280      * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
281      *
282      * @param list Array whose entries will be shuffled (in-place).
283      * @param start Index at which shuffling begins.
284      * @param towardHead Shuffling is performed for index positions between
285      * {@code start} and either the end (if {@code false}) or the beginning
286      * (if {@code true}) of the array.
287      */
288     private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list,
289                                                                       int start,
290                                                                       boolean towardHead) {
291         final int[] array = new int[list.size()];
292         ListIterator<Integer> it = list.listIterator();
293         for (int i = 0; i < array.length; i++) {
294             array[i] = it.next();
295         }
296 
297         // Identical RNGs
298         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
299         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
300 
301         ListSampler.shuffle(rng1, list, start, towardHead);
302         PermutationSampler.shuffle(rng2, array, start, towardHead);
303 
304         final Supplier<String> msg = () -> String.format("Type=%s start=%d towardHead=%b",
305                 list.getClass().getSimpleName(), start, towardHead);
306         it = list.listIterator();
307         for (int i = 0; i < array.length; i++) {
308             Assertions.assertEquals(array[i], it.next().intValue(), msg);
309         }
310     }
311 }