1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import java.util.Arrays;
20 import org.apache.commons.math3.stat.inference.ChiSquareTest;
21 import org.apache.commons.rng.UniformRandomProvider;
22 import org.junit.jupiter.api.Assertions;
23 import org.junit.jupiter.api.Test;
24
25
26
27
28 class PermutationSamplerTest {
29 @Test
30 void testSampleTrivial() {
31 final int n = 6;
32 final int k = 3;
33 final PermutationSampler sampler = new PermutationSampler(RandomAssert.seededRNG(),
34 n, k);
35 final int[] random = sampler.sample();
36 SAMPLE: for (final int s : random) {
37 for (int i = 0; i < n; i++) {
38 if (i == s) {
39 continue SAMPLE;
40 }
41 }
42 Assertions.fail("number " + s + " not in array");
43 }
44 }
45
46 @Test
47 void testSampleChiSquareTest() {
48 final int n = 3;
49 final int k = 3;
50 final int[][] p = {{0, 1, 2}, {0, 2, 1},
51 {1, 0, 2}, {1, 2, 0},
52 {2, 0, 1}, {2, 1, 0}};
53 runSampleChiSquareTest(n, k, p);
54 }
55
56 @Test
57 void testSubSampleChiSquareTest() {
58 final int n = 4;
59 final int k = 2;
60 final int[][] p = {{0, 1}, {1, 0},
61 {0, 2}, {2, 0},
62 {0, 3}, {3, 0},
63 {1, 2}, {2, 1},
64 {1, 3}, {3, 1},
65 {2, 3}, {3, 2}};
66 runSampleChiSquareTest(n, k, p);
67 }
68
69 @Test
70 void testSampleBoundaryCase() {
71 final UniformRandomProvider rng = RandomAssert.seededRNG();
72
73 final PermutationSampler sampler = new PermutationSampler(rng, 1, 1);
74 final int[] perm = sampler.sample();
75 Assertions.assertEquals(1, perm.length);
76 Assertions.assertEquals(0, perm[0]);
77 }
78
79 @Test
80 void testSamplePrecondition1() {
81 final UniformRandomProvider rng = RandomAssert.seededRNG();
82
83 Assertions.assertThrows(IllegalArgumentException.class,
84 () -> new PermutationSampler(rng, 2, 3));
85 }
86
87 @Test
88 void testSamplePrecondition2() {
89 final UniformRandomProvider rng = RandomAssert.seededRNG();
90
91 Assertions.assertThrows(IllegalArgumentException.class,
92 () -> new PermutationSampler(rng, 0, 0));
93 }
94
95 @Test
96 void testSamplePrecondition3() {
97 final UniformRandomProvider rng = RandomAssert.seededRNG();
98
99 Assertions.assertThrows(IllegalArgumentException.class,
100 () -> new PermutationSampler(rng, -1, 0));
101 }
102
103 @Test
104 void testSamplePrecondition4() {
105 final UniformRandomProvider rng = RandomAssert.seededRNG();
106
107 Assertions.assertThrows(IllegalArgumentException.class,
108 () -> new PermutationSampler(rng, 1, -1));
109 }
110
111 @Test
112 void testNatural() {
113 final int n = 4;
114 final int[] expected = {0, 1, 2, 3};
115
116 final int[] natural = PermutationSampler.natural(n);
117 for (int i = 0; i < n; i++) {
118 Assertions.assertEquals(expected[i], natural[i]);
119 }
120 }
121
122 @Test
123 void testNaturalZero() {
124 final int[] natural = PermutationSampler.natural(0);
125 Assertions.assertEquals(0, natural.length);
126 }
127
128 @Test
129 void testShuffleNoDuplicates() {
130 final int n = 100;
131 final int[] orig = PermutationSampler.natural(n);
132 final UniformRandomProvider rng = RandomAssert.seededRNG();
133 PermutationSampler.shuffle(rng, orig);
134
135
136 final int[] count = new int[n];
137 for (int i = 0; i < n; i++) {
138 count[orig[i]] += 1;
139 }
140
141 for (int i = 0; i < n; i++) {
142 Assertions.assertEquals(1, count[i]);
143 }
144 }
145
146 @Test
147 void testShuffleTail() {
148 final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
149 final int[] list = orig.clone();
150 final int start = 4;
151 final UniformRandomProvider rng = RandomAssert.createRNG();
152 PermutationSampler.shuffle(rng, list, start, false);
153
154
155 for (int i = 0; i < start; i++) {
156 Assertions.assertEquals(orig[i], list[i]);
157 }
158
159
160 boolean ok = false;
161 for (int i = start; i < orig.length - 1; i++) {
162 if (orig[i] != list[i]) {
163 ok = true;
164 break;
165 }
166 }
167 Assertions.assertTrue(ok);
168 }
169
170 @Test
171 void testShuffleHead() {
172 final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
173 final int[] list = orig.clone();
174 final int start = 4;
175 final UniformRandomProvider rng = RandomAssert.createRNG();
176 PermutationSampler.shuffle(rng, list, start, true);
177
178
179 for (int i = start + 1; i < orig.length; i++) {
180 Assertions.assertEquals(orig[i], list[i]);
181 }
182
183
184 boolean ok = false;
185 for (int i = 0; i <= start; i++) {
186 if (orig[i] != list[i]) {
187 ok = true;
188 break;
189 }
190 }
191 Assertions.assertTrue(ok);
192 }
193
194
195
196
197 @Test
198 void testSharedStateSampler() {
199 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
200 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
201 final int n = 17;
202 final int k = 13;
203 final PermutationSampler sampler1 =
204 new PermutationSampler(rng1, n, k);
205 final PermutationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
206 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
207 }
208
209
210
211 private void runSampleChiSquareTest(int n,
212 int k,
213 int[][] p) {
214 final int len = p.length;
215 final long[] observed = new long[len];
216 final int numSamples = 6000;
217 final double numExpected = numSamples / (double) len;
218 final double[] expected = new double[len];
219 Arrays.fill(expected, numExpected);
220
221 final UniformRandomProvider rng = RandomAssert.createRNG();
222 final PermutationSampler sampler = new PermutationSampler(rng, n, k);
223 for (int i = 0; i < numSamples; i++) {
224 observed[findPerm(p, sampler.sample())]++;
225 }
226
227
228 final ChiSquareTest chiSquareTest = new ChiSquareTest();
229 Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
230 }
231
232 private static int findPerm(int[][] p,
233 int[] samp) {
234 for (int i = 0; i < p.length; i++) {
235 if (Arrays.equals(p[i], samp)) {
236 return i;
237 }
238 }
239 Assertions.fail("Permutation not found");
240 return -1;
241 }
242 }