1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
import java.util.Arrays;
import java.util.Locale;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.core.source64.SplitMix64;
import org.apache.commons.rng.sampling.RandomAssert;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
29
30
31
32
33 class DirichletSamplerTest {
34 @Test
35 void testDistributionThrowsWithInvalidNumberOfCategories() {
36 final UniformRandomProvider rng = RandomAssert.seededRNG();
37 Assertions.assertThrows(IllegalArgumentException.class,
38 () -> DirichletSampler.of(rng, 1.0));
39 }
40
41 @Test
42 void testDistributionThrowsWithZeroConcentration() {
43 final UniformRandomProvider rng = RandomAssert.seededRNG();
44 Assertions.assertThrows(IllegalArgumentException.class,
45 () -> DirichletSampler.of(rng, 1.0, 0.0));
46 }
47
48 @Test
49 void testDistributionThrowsWithNaNConcentration() {
50 final UniformRandomProvider rng = RandomAssert.seededRNG();
51 Assertions.assertThrows(IllegalArgumentException.class,
52 () -> DirichletSampler.of(rng, 1.0, Double.NaN));
53 }
54
55 @Test
56 void testDistributionThrowsWithInfiniteConcentration() {
57 final UniformRandomProvider rng = RandomAssert.seededRNG();
58 Assertions.assertThrows(IllegalArgumentException.class,
59 () -> DirichletSampler.of(rng, 1.0, Double.POSITIVE_INFINITY));
60 }
61
62 @Test
63 void testSymmetricDistributionThrowsWithInvalidNumberOfCategories() {
64 final UniformRandomProvider rng = RandomAssert.seededRNG();
65 Assertions.assertThrows(IllegalArgumentException.class,
66 () -> DirichletSampler.symmetric(rng, 1, 1.0));
67 }
68
69 @Test
70 void testSymmetricDistributionThrowsWithZeroConcentration() {
71 final UniformRandomProvider rng = RandomAssert.seededRNG();
72 Assertions.assertThrows(IllegalArgumentException.class,
73 () -> DirichletSampler.symmetric(rng, 2, 0.0));
74 }
75
76 @Test
77 void testSymmetricDistributionThrowsWithNaNConcentration() {
78 final UniformRandomProvider rng = RandomAssert.seededRNG();
79 Assertions.assertThrows(IllegalArgumentException.class,
80 () -> DirichletSampler.symmetric(rng, 2, Double.NaN));
81 }
82
83 @Test
84 void testSymmetricDistributionThrowsWithInfiniteConcentration() {
85 final UniformRandomProvider rng = RandomAssert.seededRNG();
86 Assertions.assertThrows(IllegalArgumentException.class,
87 () -> DirichletSampler.symmetric(rng, 2, Double.POSITIVE_INFINITY));
88 }
89
90
91
92
93
94
95 @Test
96 void testInvalidSampleIsIgnored() {
97
98
99 final UniformRandomProvider rng = new SplitMix64(0L) {
100 private int i;
101
102 @Override
103 public long next() {
104 return i++ < 10 ? 0L : super.next();
105 }
106 };
107
108
109 final DirichletSampler sampler = DirichletSampler.symmetric(rng, 2, 1.0);
110 assertSample(2, sampler.sample());
111 }
112
113 @Test
114 void testSharedStateSampler() {
115 final RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
116 final byte[] seed = randomSource.createSeed();
117 final UniformRandomProvider rng1 = randomSource.create(seed);
118 final UniformRandomProvider rng2 = randomSource.create(seed);
119 final DirichletSampler sampler1 = DirichletSampler.of(rng1, 1, 2, 3);
120 final DirichletSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
121 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
122 }
123
124 @Test
125 void testSharedStateSamplerForSymmetricCase() {
126 final RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
127 final byte[] seed = randomSource.createSeed();
128 final UniformRandomProvider rng1 = randomSource.create(seed);
129 final UniformRandomProvider rng2 = randomSource.create(seed);
130 final DirichletSampler sampler1 = DirichletSampler.symmetric(rng1, 2, 1.5);
131 final DirichletSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
132 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
133 }
134
135 @Test
136 void testSymmetricCaseMatchesGeneralCase() {
137 final RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
138 final byte[] seed = randomSource.createSeed();
139 final UniformRandomProvider rng1 = randomSource.create(seed);
140 final UniformRandomProvider rng2 = randomSource.create(seed);
141 final int k = 3;
142 final double[] alphas = new double[k];
143 for (final double alpha : new double[] {0.5, 1.0, 1.5}) {
144 Arrays.fill(alphas, alpha);
145 final DirichletSampler sampler1 = DirichletSampler.symmetric(rng1, k, alpha);
146 final DirichletSampler sampler2 = DirichletSampler.of(rng2, alphas);
147 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
148 }
149 }
150
151
152
153
154 @Test
155 void testToString() {
156 final UniformRandomProvider rng = RandomAssert.seededRNG();
157 final DirichletSampler sampler1 = DirichletSampler.symmetric(rng, 2, 1.0);
158 final DirichletSampler sampler2 = DirichletSampler.of(rng, 0.5, 1, 1.5);
159 Assertions.assertTrue(sampler1.toString().toLowerCase().contains("dirichlet"));
160 Assertions.assertTrue(sampler2.toString().toLowerCase().contains("dirichlet"));
161 }
162
163 @Test
164 void testSampling1() {
165 assertSamples(1, 2, 3);
166 }
167
168 @Test
169 void testSampling2() {
170 assertSamples(1, 1, 1);
171 }
172
173 @Test
174 void testSampling3() {
175 assertSamples(0.5, 1, 1.5);
176 }
177
178 @Test
179 void testSampling4() {
180 assertSamples(1, 3);
181 }
182
183 @Test
184 void testSampling5() {
185 assertSamples(1, 2, 3, 4);
186 }
187
188
189
190
191
192
193
194 private static void assertSamples(double... alpha) {
195
196 final UniformRandomProvider rng = RandomAssert.createRNG();
197 final DirichletSampler sampler = DirichletSampler.of(rng, alpha);
198 final int k = alpha.length;
199 final double[][] samples = new double[100000][];
200 for (int i = 0; i < samples.length; i++) {
201 final double[] x = sampler.sample();
202 assertSample(k, x);
203 samples[i] = x;
204 }
205
206
207
208
209
210 double alpha0 = 0;
211 for (int i = 0; i < k; i++) {
212 alpha0 += alpha[i];
213 }
214
215
216
217 final double relativeTolerance = 5e-2;
218
219
220 final double[] means = getColumnMeans(samples);
221 for (int i = 0; i < k; i++) {
222 final double mean = alpha[i] / alpha0;
223 Assertions.assertEquals(mean, means[i], mean * relativeTolerance, "Mean");
224 }
225
226
227
228 final double[][] covars = getCovariance(samples);
229 final double denom = alpha0 * alpha0 * (alpha0 + 1);
230 for (int i = 0; i < k; i++) {
231 final double var = alpha[i] * (alpha0 - alpha[i]) / denom;
232 Assertions.assertEquals(var, covars[i][i], var * relativeTolerance, "Variance");
233 for (int j = i + 1; j < k; j++) {
234 final double covar = -alpha[i] * alpha[j] / denom;
235 Assertions.assertEquals(covar, covars[i][j], Math.abs(covar) * relativeTolerance, "Covariance");
236 }
237 }
238 }
239
240
241
242
243
244
245
246 private static void assertSample(int k, double[] x) {
247 Assertions.assertEquals(k, x.length, "Number of categories");
248
249 double sum = x[0] + x[1];
250
251 for (int i = 2; i < x.length; i++) {
252 sum += x[i];
253 }
254 Assertions.assertEquals(1.0, sum, 1e-10, "Invalid sum");
255 }
256
257
258
259
260
261
262
263
264 private static double[] getColumnMeans(double[][] data) {
265 final Array2DRowRealMatrix m = new Array2DRowRealMatrix(data, false);
266 final Mean mean = new Mean();
267 final double[] means = new double[m.getColumnDimension()];
268 for (int i = 0; i < means.length; i++) {
269 means[i] = mean.evaluate(m.getColumn(i));
270 }
271 return means;
272 }
273
274
275
276
277
278
279
280 private static double[][] getCovariance(double[][] data) {
281 final Array2DRowRealMatrix m = new Array2DRowRealMatrix(data, false);
282 return new Covariance(m).getCovarianceMatrix().getData();
283 }
284 }