View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.Arrays;
20  import org.apache.commons.math3.stat.inference.ChiSquareTest;
21  import org.apache.commons.math3.util.CombinatoricsUtils;
22  import org.apache.commons.rng.UniformRandomProvider;
23  import org.junit.jupiter.api.Assertions;
24  import org.junit.jupiter.api.Test;
25  
26  /**
27   * Tests for {@link CombinationSampler}.
28   */
29  class CombinationSamplerTest {
30      @Test
31      void testSampleIsInDomain() {
32          final UniformRandomProvider rng = RandomAssert.seededRNG();
33          final int n = 6;
34          for (int k = 1; k <= n; k++) {
35              final CombinationSampler sampler = new CombinationSampler(rng, n, k);
36              final int[] random = sampler.sample();
37              for (final int s : random) {
38                  assertIsInDomain(n, s);
39              }
40          }
41      }
42  
43      @Test
44      void testUniformWithKlessThanHalfN() {
45          final int n = 8;
46          final int k = 2;
47          assertUniformSamples(n, k);
48      }
49  
50      @Test
51      void testUniformWithKmoreThanHalfN() {
52          final int n = 8;
53          final int k = 6;
54          assertUniformSamples(n, k);
55      }
56  
57      @Test
58      void testSampleWhenNequalsKIsNotShuffled() {
59          final UniformRandomProvider rng = RandomAssert.seededRNG();
60          // Check n == k boundary case.
61          // This is allowed but the sample is not shuffled.
62          for (int n = 1; n < 3; n++) {
63              final int k = n;
64              final CombinationSampler sampler = new CombinationSampler(rng, n, k);
65              final int[] sample = sampler.sample();
66              Assertions.assertEquals(n, sample.length, "Incorrect sample length");
67              for (int i = 0; i < n; i++) {
68                  Assertions.assertEquals(i, sample[i], "Sample was shuffled");
69              }
70          }
71      }
72  
73      @Test
74      void testKgreaterThanNThrows() {
75          final UniformRandomProvider rng = RandomAssert.seededRNG();
76          // Must fail for k > n.
77          final int n = 2;
78          final int k = 3;
79          Assertions.assertThrows(IllegalArgumentException.class,
80              () -> new CombinationSampler(rng, n, k));
81      }
82  
83      @Test
84      void testNequalsZeroThrows() {
85          final UniformRandomProvider rng = RandomAssert.seededRNG();
86          // Must fail for n = 0.
87          final int n = 0;
88          final int k = 3;
89          Assertions.assertThrows(IllegalArgumentException.class,
90              () -> new CombinationSampler(rng, n, k));
91      }
92  
93      @Test
94      void testKequalsZeroThrows() {
95          final UniformRandomProvider rng = RandomAssert.seededRNG();
96          // Must fail for k = 0.
97          final int n = 2;
98          final int k = 0;
99          Assertions.assertThrows(IllegalArgumentException.class,
100             () -> new CombinationSampler(rng, n, k));
101     }
102 
103     @Test
104     void testNisNegativeThrows() {
105         final UniformRandomProvider rng = RandomAssert.seededRNG();
106         // Must fail for n <= 0.
107         final int n = -1;
108         final int k = 3;
109         Assertions.assertThrows(IllegalArgumentException.class,
110             () -> new CombinationSampler(rng, n, k));
111     }
112 
113     @Test
114     void testKisNegativeThrows() {
115         final UniformRandomProvider rng = RandomAssert.seededRNG();
116         // Must fail for k <= 0.
117         final int n = 0;
118         final int k = -1;
119         Assertions.assertThrows(IllegalArgumentException.class,
120             () -> new CombinationSampler(rng, n, k));
121     }
122 
123     /**
124      * Test the SharedStateSampler implementation.
125      */
126     @Test
127     void testSharedStateSampler() {
128         final UniformRandomProvider rng1 = RandomAssert.seededRNG();
129         final UniformRandomProvider rng2 = RandomAssert.seededRNG();
130         final int n = 17;
131         final int k = 3;
132         final CombinationSampler sampler1 =
133             new CombinationSampler(rng1, n, k);
134         final CombinationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
135         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
136     }
137 
138     //// Support methods.
139 
140     /**
141      * Asserts the sample value is in the range 0 to n-1.
142      *
143      * @param n     the n
144      * @param value the sample value
145      */
146     private static void assertIsInDomain(int n, int value) {
147         if (value < 0 || value >= n) {
148             Assertions.fail("sample " + value + " not in the domain " + n);
149         }
150     }
151 
152     private void assertUniformSamples(int n, int k) {
153         // The C(n, k) should generate a sample of unspecified order.
154         // To test this each combination is allocated a unique code
155         // based on setting k of the first n-bits in an integer.
156         // Codes are positive for all combinations of bits that use k-bits,
157         // otherwise they are negative.
158         final int totalBitCombinations = 1 << n;
159         final int[] codeLookup = new int[totalBitCombinations];
160         Arrays.fill(codeLookup, -1); // initialize as negative
161         int codes = 0;
162         for (int i = 0; i < totalBitCombinations; i++) {
163             if (Integer.bitCount(i) == k) {
164                 // This is a valid sample so allocate a code
165                 codeLookup[i] = codes++;
166             }
167         }
168 
169         // The number of combinations C(n, k) is the binomial coefficient
170         Assertions.assertEquals(CombinatoricsUtils.binomialCoefficient(n, k), codes,
171             "Incorrect number of combination codes");
172 
173         final long[] observed = new long[codes];
174         final int numSamples = 6000;
175 
176         final UniformRandomProvider rng = RandomAssert.createRNG();
177         final CombinationSampler sampler = new CombinationSampler(rng, n, k);
178         for (int i = 0; i < numSamples; i++) {
179             observed[findCode(codeLookup, sampler.sample())]++;
180         }
181 
182         // Chi squared test of uniformity
183         final double numExpected = numSamples / (double) codes;
184         final double[] expected = new double[codes];
185         Arrays.fill(expected, numExpected);
186         final ChiSquareTest chiSquareTest = new ChiSquareTest();
187         // Pass if we cannot reject null hypothesis that distributions are the same.
188         Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
189     }
190 
191     private static int findCode(int[] codeLookup, int[] sample) {
192         // Each sample index is used to set a bit in an integer.
193         // The resulting bits should be a valid code.
194         int bits = 0;
195         for (final int s : sample) {
196             // This shift will be from 0 to n-1 since it is from the
197             // domain of size n.
198             bits |= 1 << s;
199         }
200         if (bits >= codeLookup.length) {
201             Assertions.fail("Bad bit combination: " + Arrays.toString(sample));
202         }
203         final int code = codeLookup[bits];
204         if (code < 0) {
205             Assertions.fail("Bad bit code: " + Arrays.toString(sample));
206         }
207         return code;
208     }
209 }