View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.Arrays;
20  import org.apache.commons.math3.stat.inference.ChiSquareTest;
21  import org.apache.commons.rng.UniformRandomProvider;
22  import org.junit.jupiter.api.Assertions;
23  import org.junit.jupiter.api.Test;
24  
25  /**
26   * Tests for {@link PermutationSampler}.
27   */
28  class PermutationSamplerTest {
29      @Test
30      void testSampleTrivial() {
31          final int n = 6;
32          final int k = 3;
33          final PermutationSampler sampler = new PermutationSampler(RandomAssert.seededRNG(),
34                                                                    n, k);
35          final int[] random = sampler.sample();
36          SAMPLE: for (final int s : random) {
37              for (int i = 0; i < n; i++) {
38                  if (i == s) {
39                      continue SAMPLE;
40                  }
41              }
42              Assertions.fail("number " + s + " not in array");
43          }
44      }
45  
46      @Test
47      void testSampleChiSquareTest() {
48          final int n = 3;
49          final int k = 3;
50          final int[][] p = {{0, 1, 2}, {0, 2, 1},
51                             {1, 0, 2}, {1, 2, 0},
52                             {2, 0, 1}, {2, 1, 0}};
53          runSampleChiSquareTest(n, k, p);
54      }
55  
56      @Test
57      void testSubSampleChiSquareTest() {
58          final int n = 4;
59          final int k = 2;
60          final int[][] p = {{0, 1}, {1, 0},
61                             {0, 2}, {2, 0},
62                             {0, 3}, {3, 0},
63                             {1, 2}, {2, 1},
64                             {1, 3}, {3, 1},
65                             {2, 3}, {3, 2}};
66          runSampleChiSquareTest(n, k, p);
67      }
68  
69      @Test
70      void testSampleBoundaryCase() {
71          final UniformRandomProvider rng = RandomAssert.seededRNG();
72          // Check size = 1 boundary case.
73          final PermutationSampler sampler = new PermutationSampler(rng, 1, 1);
74          final int[] perm = sampler.sample();
75          Assertions.assertEquals(1, perm.length);
76          Assertions.assertEquals(0, perm[0]);
77      }
78  
79      @Test
80      void testSamplePrecondition1() {
81          final UniformRandomProvider rng = RandomAssert.seededRNG();
82          // Must fail for k > n.
83          Assertions.assertThrows(IllegalArgumentException.class,
84              () -> new PermutationSampler(rng, 2, 3));
85      }
86  
87      @Test
88      void testSamplePrecondition2() {
89          final UniformRandomProvider rng = RandomAssert.seededRNG();
90          // Must fail for n = 0.
91          Assertions.assertThrows(IllegalArgumentException.class,
92              () -> new PermutationSampler(rng, 0, 0));
93      }
94  
95      @Test
96      void testSamplePrecondition3() {
97          final UniformRandomProvider rng = RandomAssert.seededRNG();
98          // Must fail for k < n < 0.
99          Assertions.assertThrows(IllegalArgumentException.class,
100             () -> new PermutationSampler(rng, -1, 0));
101     }
102 
103     @Test
104     void testSamplePrecondition4() {
105         final UniformRandomProvider rng = RandomAssert.seededRNG();
106         // Must fail for k < n < 0.
107         Assertions.assertThrows(IllegalArgumentException.class,
108             () -> new PermutationSampler(rng, 1, -1));
109     }
110 
111     @Test
112     void testNatural() {
113         final int n = 4;
114         final int[] expected = {0, 1, 2, 3};
115 
116         final int[] natural = PermutationSampler.natural(n);
117         for (int i = 0; i < n; i++) {
118             Assertions.assertEquals(expected[i], natural[i]);
119         }
120     }
121 
122     @Test
123     void testNaturalZero() {
124         final int[] natural = PermutationSampler.natural(0);
125         Assertions.assertEquals(0, natural.length);
126     }
127 
128     @Test
129     void testShuffleNoDuplicates() {
130         final int n = 100;
131         final int[] orig = PermutationSampler.natural(n);
132         final UniformRandomProvider rng = RandomAssert.seededRNG();
133         PermutationSampler.shuffle(rng, orig);
134 
135         // Test that all (unique) entries exist in the shuffled array.
136         final int[] count = new int[n];
137         for (int i = 0; i < n; i++) {
138             count[orig[i]] += 1;
139         }
140 
141         for (int i = 0; i < n; i++) {
142             Assertions.assertEquals(1, count[i]);
143         }
144     }
145 
146     @Test
147     void testShuffleTail() {
148         final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
149         final int[] list = orig.clone();
150         final int start = 4;
151         final UniformRandomProvider rng = RandomAssert.createRNG();
152         PermutationSampler.shuffle(rng, list, start, false);
153 
154         // Ensure that all entries below index "start" did not move.
155         for (int i = 0; i < start; i++) {
156             Assertions.assertEquals(orig[i], list[i]);
157         }
158 
159         // Ensure that at least one entry has moved.
160         boolean ok = false;
161         for (int i = start; i < orig.length - 1; i++) {
162             if (orig[i] != list[i]) {
163                 ok = true;
164                 break;
165             }
166         }
167         Assertions.assertTrue(ok);
168     }
169 
170     @Test
171     void testShuffleHead() {
172         final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
173         final int[] list = orig.clone();
174         final int start = 4;
175         final UniformRandomProvider rng = RandomAssert.createRNG();
176         PermutationSampler.shuffle(rng, list, start, true);
177 
178         // Ensure that all entries above index "start" did not move.
179         for (int i = start + 1; i < orig.length; i++) {
180             Assertions.assertEquals(orig[i], list[i]);
181         }
182 
183         // Ensure that at least one entry has moved.
184         boolean ok = false;
185         for (int i = 0; i <= start; i++) {
186             if (orig[i] != list[i]) {
187                 ok = true;
188                 break;
189             }
190         }
191         Assertions.assertTrue(ok);
192     }
193 
194     /**
195      * Test the SharedStateSampler implementation.
196      */
197     @Test
198     void testSharedStateSampler() {
199         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
200         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
201         final int n = 17;
202         final int k = 13;
203         final PermutationSampler sampler1 =
204             new PermutationSampler(rng1, n, k);
205         final PermutationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
206         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
207     }
208 
209     //// Support methods.
210 
211     private void runSampleChiSquareTest(int n,
212                                         int k,
213                                         int[][] p) {
214         final int len = p.length;
215         final long[] observed = new long[len];
216         final int numSamples = 6000;
217         final double numExpected = numSamples / (double) len;
218         final double[] expected = new double[len];
219         Arrays.fill(expected, numExpected);
220 
221         final UniformRandomProvider rng = RandomAssert.createRNG();
222         final PermutationSampler sampler = new PermutationSampler(rng, n, k);
223         for (int i = 0; i < numSamples; i++) {
224             observed[findPerm(p, sampler.sample())]++;
225         }
226 
227         // Pass if we cannot reject null hypothesis that distributions are the same.
228         final ChiSquareTest chiSquareTest = new ChiSquareTest();
229         Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
230     }
231 
232     private static int findPerm(int[][] p,
233                                 int[] samp) {
234         for (int i = 0; i < p.length; i++) {
235             if (Arrays.equals(p[i], samp)) {
236                 return i;
237             }
238         }
239         Assertions.fail("Permutation not found");
240         return -1;
241     }
242 }