View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import java.util.Arrays;
20  import org.apache.commons.math3.linear.Array2DRowRealMatrix;
21  import org.apache.commons.math3.stat.correlation.Covariance;
22  import org.apache.commons.math3.stat.descriptive.moment.Mean;
23  import org.apache.commons.rng.UniformRandomProvider;
24  import org.apache.commons.rng.core.source64.SplitMix64;
25  import org.apache.commons.rng.sampling.RandomAssert;
26  import org.apache.commons.rng.simple.RandomSource;
27  import org.junit.jupiter.api.Assertions;
28  import org.junit.jupiter.api.Test;
29  
30  /**
31   * Test for {@link DirichletSampler}.
32   */
33  class DirichletSamplerTest {
34      @Test
35      void testDistributionThrowsWithInvalidNumberOfCategories() {
36          final UniformRandomProvider rng = RandomAssert.seededRNG();
37          Assertions.assertThrows(IllegalArgumentException.class,
38              () -> DirichletSampler.of(rng, 1.0));
39      }
40  
41      @Test
42      void testDistributionThrowsWithZeroConcentration() {
43          final UniformRandomProvider rng = RandomAssert.seededRNG();
44          Assertions.assertThrows(IllegalArgumentException.class,
45              () -> DirichletSampler.of(rng, 1.0, 0.0));
46      }
47  
48      @Test
49      void testDistributionThrowsWithNaNConcentration() {
50          final UniformRandomProvider rng = RandomAssert.seededRNG();
51          Assertions.assertThrows(IllegalArgumentException.class,
52              () -> DirichletSampler.of(rng, 1.0, Double.NaN));
53      }
54  
55      @Test
56      void testDistributionThrowsWithInfiniteConcentration() {
57          final UniformRandomProvider rng = RandomAssert.seededRNG();
58          Assertions.assertThrows(IllegalArgumentException.class,
59              () -> DirichletSampler.of(rng, 1.0, Double.POSITIVE_INFINITY));
60      }
61  
62      @Test
63      void testSymmetricDistributionThrowsWithInvalidNumberOfCategories() {
64          final UniformRandomProvider rng = RandomAssert.seededRNG();
65          Assertions.assertThrows(IllegalArgumentException.class,
66                  () -> DirichletSampler.symmetric(rng, 1, 1.0));
67      }
68  
69      @Test
70      void testSymmetricDistributionThrowsWithZeroConcentration() {
71          final UniformRandomProvider rng = RandomAssert.seededRNG();
72          Assertions.assertThrows(IllegalArgumentException.class,
73              () -> DirichletSampler.symmetric(rng, 2, 0.0));
74      }
75  
76      @Test
77      void testSymmetricDistributionThrowsWithNaNConcentration() {
78          final UniformRandomProvider rng = RandomAssert.seededRNG();
79          Assertions.assertThrows(IllegalArgumentException.class,
80              () -> DirichletSampler.symmetric(rng, 2, Double.NaN));
81      }
82  
83      @Test
84      void testSymmetricDistributionThrowsWithInfiniteConcentration() {
85          final UniformRandomProvider rng = RandomAssert.seededRNG();
86          Assertions.assertThrows(IllegalArgumentException.class,
87              () -> DirichletSampler.symmetric(rng, 2, Double.POSITIVE_INFINITY));
88      }
89  
90      /**
91       * Create condition so that all samples are zero and it is impossible to normalise the
92       * samples to sum to 1. These should be ignored and the sample is repeated until
93       * normalisation is possible.
94       */
95      @Test
96      void testInvalidSampleIsIgnored() {
97          // An RNG implementation which should create zero samples from the underlying
98          // exponential sampler for an initial sequence.
99          final UniformRandomProvider rng = new SplitMix64(0L) {
100             private int i;
101 
102             @Override
103             public long next() {
104                 return i++ < 10 ? 0L : super.next();
105             }
106         };
107 
108         // Alpha=1 will use an exponential sampler
109         final DirichletSampler sampler = DirichletSampler.symmetric(rng, 2, 1.0);
110         assertSample(2, sampler.sample());
111     }
112 
113     @Test
114     void testSharedStateSampler() {
115         final RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
116         final byte[] seed = randomSource.createSeed();
117         final UniformRandomProvider rng1 = randomSource.create(seed);
118         final UniformRandomProvider rng2 = randomSource.create(seed);
119         final DirichletSampler sampler1 = DirichletSampler.of(rng1, 1, 2, 3);
120         final DirichletSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
121         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
122     }
123 
124     @Test
125     void testSharedStateSamplerForSymmetricCase() {
126         final RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
127         final byte[] seed = randomSource.createSeed();
128         final UniformRandomProvider rng1 = randomSource.create(seed);
129         final UniformRandomProvider rng2 = randomSource.create(seed);
130         final DirichletSampler sampler1 = DirichletSampler.symmetric(rng1, 2, 1.5);
131         final DirichletSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
132         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
133     }
134 
135     @Test
136     void testSymmetricCaseMatchesGeneralCase() {
137         final RandomSource randomSource = RandomSource.XO_RO_SHI_RO_128_PP;
138         final byte[] seed = randomSource.createSeed();
139         final UniformRandomProvider rng1 = randomSource.create(seed);
140         final UniformRandomProvider rng2 = randomSource.create(seed);
141         final int k = 3;
142         final double[] alphas = new double[k];
143         for (final double alpha : new double[] {0.5, 1.0, 1.5}) {
144             Arrays.fill(alphas, alpha);
145             final DirichletSampler sampler1 = DirichletSampler.symmetric(rng1, k, alpha);
146             final DirichletSampler sampler2 = DirichletSampler.of(rng2, alphas);
147             RandomAssert.assertProduceSameSequence(sampler1, sampler2);
148         }
149     }
150 
151     /**
152      * Test the toString method. This is added to ensure coverage.
153      */
154     @Test
155     void testToString() {
156         final UniformRandomProvider rng = RandomAssert.seededRNG();
157         final DirichletSampler sampler1 = DirichletSampler.symmetric(rng, 2, 1.0);
158         final DirichletSampler sampler2 = DirichletSampler.of(rng, 0.5, 1, 1.5);
159         Assertions.assertTrue(sampler1.toString().toLowerCase().contains("dirichlet"));
160         Assertions.assertTrue(sampler2.toString().toLowerCase().contains("dirichlet"));
161     }
162 
163     @Test
164     void testSampling1() {
165         assertSamples(1, 2, 3);
166     }
167 
168     @Test
169     void testSampling2() {
170         assertSamples(1, 1, 1);
171     }
172 
173     @Test
174     void testSampling3() {
175         assertSamples(0.5, 1, 1.5);
176     }
177 
178     @Test
179     void testSampling4() {
180         assertSamples(1, 3);
181     }
182 
183     @Test
184     void testSampling5() {
185         assertSamples(1, 2, 3, 4);
186     }
187 
188     /**
189      * Assert samples from the distribution. The variates are tested against the expected
190      * mean and covariance for the given concentration parameters.
191      *
192      * @param alpha Concentration parameters.
193      */
194     private static void assertSamples(double... alpha) {
195         // No fixed seed. Failed tests will be repeated by the JUnit test runner.
196         final UniformRandomProvider rng = RandomAssert.createRNG();
197         final DirichletSampler sampler = DirichletSampler.of(rng, alpha);
198         final int k = alpha.length;
199         final double[][] samples = new double[100000][];
200         for (int i = 0; i < samples.length; i++) {
201             final double[] x = sampler.sample();
202             assertSample(k, x);
203             samples[i] = x;
204         }
205 
206         // Computation of moments:
207         // https://en.wikipedia.org/wiki/Dirichlet_distribution#Moments
208 
209         // Compute the sum of the concentration parameters: alpha0
210         double alpha0 = 0;
211         for (int i = 0; i < k; i++) {
212             alpha0 += alpha[i];
213         }
214 
215         // Use a moderate tolerance.
216         // Differences are usually observed in the 3rd significant figure.
217         final double relativeTolerance = 5e-2;
218 
219         // Mean = alpha[i] / alpha0
220         final double[] means = getColumnMeans(samples);
221         for (int i = 0; i < k; i++) {
222             final double mean = alpha[i] / alpha0;
223             Assertions.assertEquals(mean, means[i], mean * relativeTolerance, "Mean");
224         }
225 
226         // Variance = alpha_i (alpha_0 - alpha_i) / alpha_0^2 (alpha_0 + 1)
227         // Covariance = -alpha_i * alpha_j / alpha_0^2 (alpha_0 + 1)
228         final double[][] covars = getCovariance(samples);
229         final double denom = alpha0 * alpha0 * (alpha0 + 1);
230         for (int i = 0; i < k; i++) {
231             final double var = alpha[i] * (alpha0 - alpha[i]) / denom;
232             Assertions.assertEquals(var, covars[i][i], var * relativeTolerance, "Variance");
233             for (int j = i + 1; j < k; j++) {
234                 final double covar = -alpha[i] * alpha[j] / denom;
235                 Assertions.assertEquals(covar, covars[i][j], Math.abs(covar) * relativeTolerance, "Covariance");
236             }
237         }
238     }
239 
240     /**
241      * Assert the sample has the correct length and sums to 1.
242      *
243      * @param k Expected number of categories.
244      * @param x Sample.
245      */
246     private static void assertSample(int k, double[] x) {
247         Assertions.assertEquals(k, x.length, "Number of categories");
248         // There are always at least 2 categories
249         double sum = x[0] + x[1];
250         // Sum the rest
251         for (int i = 2; i < x.length; i++) {
252             sum += x[i];
253         }
254         Assertions.assertEquals(1.0, sum, 1e-10, "Invalid sum");
255     }
256 
257     /**
258      * Gets the column means. This is done using the same method as the means in the
259      * Apache Commons Math Covariance class by using the Mean class.
260      *
261      * @param data the data
262      * @return the column means
263      */
264     private static double[] getColumnMeans(double[][] data) {
265         final Array2DRowRealMatrix m = new Array2DRowRealMatrix(data, false);
266         final Mean mean = new Mean();
267         final double[] means = new double[m.getColumnDimension()];
268         for (int i = 0; i < means.length; i++) {
269             means[i] = mean.evaluate(m.getColumn(i));
270         }
271         return means;
272     }
273 
274     /**
275      * Gets the covariance.
276      *
277      * @param data the data
278      * @return the covariance
279      */
280     private static double[][] getCovariance(double[][] data) {
281         final Array2DRowRealMatrix m = new Array2DRowRealMatrix(data, false);
282         return new Covariance(m).getCovarianceMatrix().getData();
283     }
284 }