1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import java.util.Arrays;
20 import org.apache.commons.math3.stat.inference.ChiSquareTest;
21 import org.apache.commons.math3.util.CombinatoricsUtils;
22 import org.apache.commons.rng.UniformRandomProvider;
23 import org.junit.jupiter.api.Assertions;
24 import org.junit.jupiter.api.Test;
25
26
27
28
29 class CombinationSamplerTest {
30 @Test
31 void testSampleIsInDomain() {
32 final UniformRandomProvider rng = RandomAssert.seededRNG();
33 final int n = 6;
34 for (int k = 1; k <= n; k++) {
35 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
36 final int[] random = sampler.sample();
37 for (final int s : random) {
38 assertIsInDomain(n, s);
39 }
40 }
41 }
42
43 @Test
44 void testUniformWithKlessThanHalfN() {
45 final int n = 8;
46 final int k = 2;
47 assertUniformSamples(n, k);
48 }
49
50 @Test
51 void testUniformWithKmoreThanHalfN() {
52 final int n = 8;
53 final int k = 6;
54 assertUniformSamples(n, k);
55 }
56
57 @Test
58 void testSampleWhenNequalsKIsNotShuffled() {
59 final UniformRandomProvider rng = RandomAssert.seededRNG();
60
61
62 for (int n = 1; n < 3; n++) {
63 final int k = n;
64 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
65 final int[] sample = sampler.sample();
66 Assertions.assertEquals(n, sample.length, "Incorrect sample length");
67 for (int i = 0; i < n; i++) {
68 Assertions.assertEquals(i, sample[i], "Sample was shuffled");
69 }
70 }
71 }
72
73 @Test
74 void testKgreaterThanNThrows() {
75 final UniformRandomProvider rng = RandomAssert.seededRNG();
76
77 final int n = 2;
78 final int k = 3;
79 Assertions.assertThrows(IllegalArgumentException.class,
80 () -> new CombinationSampler(rng, n, k));
81 }
82
83 @Test
84 void testNequalsZeroThrows() {
85 final UniformRandomProvider rng = RandomAssert.seededRNG();
86
87 final int n = 0;
88 final int k = 3;
89 Assertions.assertThrows(IllegalArgumentException.class,
90 () -> new CombinationSampler(rng, n, k));
91 }
92
93 @Test
94 void testKequalsZeroThrows() {
95 final UniformRandomProvider rng = RandomAssert.seededRNG();
96
97 final int n = 2;
98 final int k = 0;
99 Assertions.assertThrows(IllegalArgumentException.class,
100 () -> new CombinationSampler(rng, n, k));
101 }
102
103 @Test
104 void testNisNegativeThrows() {
105 final UniformRandomProvider rng = RandomAssert.seededRNG();
106
107 final int n = -1;
108 final int k = 3;
109 Assertions.assertThrows(IllegalArgumentException.class,
110 () -> new CombinationSampler(rng, n, k));
111 }
112
113 @Test
114 void testKisNegativeThrows() {
115 final UniformRandomProvider rng = RandomAssert.seededRNG();
116
117 final int n = 0;
118 final int k = -1;
119 Assertions.assertThrows(IllegalArgumentException.class,
120 () -> new CombinationSampler(rng, n, k));
121 }
122
123
124
125
126 @Test
127 void testSharedStateSampler() {
128 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
129 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
130 final int n = 17;
131 final int k = 3;
132 final CombinationSampler sampler1 =
133 new CombinationSampler(rng1, n, k);
134 final CombinationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
135 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
136 }
137
138
139
140
141
142
143
144
145
146 private static void assertIsInDomain(int n, int value) {
147 if (value < 0 || value >= n) {
148 Assertions.fail("sample " + value + " not in the domain " + n);
149 }
150 }
151
152 private void assertUniformSamples(int n, int k) {
153
154
155
156
157
158 final int totalBitCombinations = 1 << n;
159 final int[] codeLookup = new int[totalBitCombinations];
160 Arrays.fill(codeLookup, -1);
161 int codes = 0;
162 for (int i = 0; i < totalBitCombinations; i++) {
163 if (Integer.bitCount(i) == k) {
164
165 codeLookup[i] = codes++;
166 }
167 }
168
169
170 Assertions.assertEquals(CombinatoricsUtils.binomialCoefficient(n, k), codes,
171 "Incorrect number of combination codes");
172
173 final long[] observed = new long[codes];
174 final int numSamples = 6000;
175
176 final UniformRandomProvider rng = RandomAssert.createRNG();
177 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
178 for (int i = 0; i < numSamples; i++) {
179 observed[findCode(codeLookup, sampler.sample())]++;
180 }
181
182
183 final double numExpected = numSamples / (double) codes;
184 final double[] expected = new double[codes];
185 Arrays.fill(expected, numExpected);
186 final ChiSquareTest chiSquareTest = new ChiSquareTest();
187
188 Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
189 }
190
191 private static int findCode(int[] codeLookup, int[] sample) {
192
193
194 int bits = 0;
195 for (final int s : sample) {
196
197
198 bits |= 1 << s;
199 }
200 if (bits >= codeLookup.length) {
201 Assertions.fail("Bad bit combination: " + Arrays.toString(sample));
202 }
203 final int code = codeLookup[bits];
204 if (code < 0) {
205 Assertions.fail("Bad bit code: " + Arrays.toString(sample));
206 }
207 return code;
208 }
209 }