1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.sampling;
19
20 import java.util.Arrays;
21 import java.util.Collections;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.TreeMap;
25 import org.apache.commons.rng.UniformRandomProvider;
26 import org.junit.jupiter.api.Assertions;
27 import org.junit.jupiter.api.Test;
28
29
30
31
32 class DiscreteProbabilityCollectionSamplerTest {
33
34 private final UniformRandomProvider rng = RandomAssert.createRNG();
35
36 @Test
37 void testPrecondition1() {
38
39 final List<Double> collection = Arrays.asList(1d, 2d);
40 final double[] probabilities = {0};
41 Assertions.assertThrows(IllegalArgumentException.class,
42 () -> new DiscreteProbabilityCollectionSampler<>(rng,
43 collection,
44 probabilities));
45 }
46
47 @Test
48 void testPrecondition2() {
49
50 final List<Double> collection = Arrays.asList(1d, 2d);
51 final double[] probabilities = {0, -1};
52 Assertions.assertThrows(IllegalArgumentException.class,
53 () -> new DiscreteProbabilityCollectionSampler<>(rng,
54 collection,
55 probabilities));
56 }
57
58 @Test
59 void testPrecondition3() {
60
61 final List<Double> collection = Arrays.asList(1d, 2d);
62 final double[] probabilities = {0, 0};
63 Assertions.assertThrows(IllegalArgumentException.class,
64 () -> new DiscreteProbabilityCollectionSampler<>(rng,
65 collection,
66 probabilities));
67 }
68
69 @Test
70 void testPrecondition4() {
71
72 final List<Double> collection = Arrays.asList(1d, 2d);
73 final double[] probabilities = {0, Double.NaN};
74 Assertions.assertThrows(IllegalArgumentException.class,
75 () -> new DiscreteProbabilityCollectionSampler<>(rng,
76 collection,
77 probabilities));
78 }
79
80 @Test
81 void testPrecondition5() {
82
83 final List<Double> collection = Arrays.asList(1d, 2d);
84 final double[] probabilities = {0, Double.POSITIVE_INFINITY};
85 Assertions.assertThrows(IllegalArgumentException.class,
86 () -> new DiscreteProbabilityCollectionSampler<>(rng,
87 collection,
88 probabilities));
89 }
90
91 @Test
92 void testPrecondition6() {
93
94 final Map<String, Double> collection = Collections.emptyMap();
95 Assertions.assertThrows(IllegalArgumentException.class,
96 () -> new DiscreteProbabilityCollectionSampler<>(rng,
97 collection));
98 }
99
100 @Test
101 void testPrecondition7() {
102
103 final List<Double> collection = Collections.emptyList();
104 final double[] probabilities = {};
105 Assertions.assertThrows(IllegalArgumentException.class,
106 () -> new DiscreteProbabilityCollectionSampler<>(rng,
107 collection,
108 probabilities));
109 }
110
111 @Test
112 void testSample() {
113 final DiscreteProbabilityCollectionSampler<Double> sampler =
114 new DiscreteProbabilityCollectionSampler<>(rng,
115 Arrays.asList(3d, -1d, 3d, 7d, -2d, 8d),
116 new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
117 final double expectedMean = 3.4;
118 final double expectedVariance = 7.84;
119
120 final int n = 100000000;
121 double sum = 0;
122 double sumOfSquares = 0;
123 for (int i = 0; i < n; i++) {
124 final double rand = sampler.sample();
125 sum += rand;
126 sumOfSquares += rand * rand;
127 }
128
129 final double mean = sum / n;
130 Assertions.assertEquals(expectedMean, mean, 1e-3);
131 final double variance = sumOfSquares / n - mean * mean;
132 Assertions.assertEquals(expectedVariance, variance, 2e-3);
133 }
134
135
136 @Test
137 void testSampleUsingMap() {
138 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
139 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
140 final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
141 final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
142 final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
143 new DiscreteProbabilityCollectionSampler<>(rng1, items, probabilities);
144
145
146 final Map<Integer, Double> map = new TreeMap<>();
147 for (int i = 0; i < probabilities.length; i++) {
148 map.put(items.get(i), probabilities[i]);
149 }
150 final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
151 new DiscreteProbabilityCollectionSampler<>(rng2, map);
152
153 for (int i = 0; i < 50; i++) {
154 Assertions.assertEquals(sampler1.sample(), sampler2.sample());
155 }
156 }
157
158
159
160
161
162
163 @Test
164 void testSampleWithProbabilityAtLastItem() {
165
166
167 final UniformRandomProvider dummyRng = new UniformRandomProvider() {
168 private int count;
169
170 @Override
171 public long nextLong() {
172 return 0;
173 }
174
175 @Override
176 public double nextDouble() {
177
178 return (count++ == 0) ? 0 : 1.0;
179 }
180 };
181
182 final List<Double> items = Arrays.asList(1d, 2d);
183 final DiscreteProbabilityCollectionSampler<Double> sampler =
184 new DiscreteProbabilityCollectionSampler<>(dummyRng,
185 items,
186 new double[] {0.5, 0.5});
187 final Double item1 = sampler.sample();
188 final Double item2 = sampler.sample();
189
190 Assertions.assertTrue(items.contains(item1), "Sample item1 is not from the list");
191 Assertions.assertTrue(items.contains(item2), "Sample item2 is not from the list");
192
193 Assertions.assertNotSame(item1, item2, "Item1 and 2 should be different");
194 }
195
196
197
198
199 @Test
200 void testSharedStateSampler() {
201 final UniformRandomProvider rng1 = RandomAssert.seededRNG();
202 final UniformRandomProvider rng2 = RandomAssert.seededRNG();
203 final List<Double> items = Arrays.asList(1d, 2d, 3d, 4d);
204 final DiscreteProbabilityCollectionSampler<Double> sampler1 =
205 new DiscreteProbabilityCollectionSampler<>(rng1,
206 items,
207 new double[] {0.1, 0.2, 0.3, 0.4});
208 final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
209 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
210 }
211 }